1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases to restore when this message is re-opened in an editor.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One piece of a message body: plain text, visible chain-of-thought, or
/// provider-redacted thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has user-visible content worth displaying.
    ///
    /// Fix: this previously returned `text.is_empty()`, inverting the intended
    /// semantics documented on [`Message::should_display_content`] — empty
    /// segments were reported as displayable and non-empty ones as hidden.
    pub fn should_display(&self) -> bool {
        match self {
            // Text and thinking segments are displayable only when non-empty.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque bytes; never display it.
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees (and their git state) at a point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Named for buffers with unsaved changes; populated outside this chunk.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its path and (if any) git repository state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// None when the worktree has no captured git state.
    pub git_state: Option<GitState>,
}

/// Captured git repository metadata for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, if captured.
    pub diff: Option<String>,
}
216
/// A git checkpoint tied to a message, used by [`Thread::restore_checkpoint`]
/// to roll the working tree back to its state before that message's edits.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token count for a thread relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold (default 0.8) can be overridden
    /// with the `ZED_THREAD_WARNING_THRESHOLD` environment variable.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            // `unwrap_or_else` avoids allocating the default when the var is set
            // (clippy: or_fun_call).
            .unwrap_or_else(|_| "0.8".to_string())
            .parse()
            // Panic with a useful message instead of a bare unwrap; debug-only.
            .expect("ZED_THREAD_WARNING_THRESHOLD must parse as an f32");
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            // Saturate instead of panicking (debug) / wrapping (release) on overflow.
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// How close token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Where a pending completion currently sits relative to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue information yet.
    Sending,
    /// Request is waiting in a queue at the given position.
    Queued { position: usize },
    /// The completion has started.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short auto-generated title plus the background task generating it.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    // Detailed-summary generation task and the watch channel holding its state.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept sorted by ascending ID (IDs assigned monotonically on insert);
    // `Self::message` relies on this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, for checkpoint restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the latest user message, finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project snapshot captured in the background when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Token accounting: per-request history and running totals.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    // User feedback for the whole thread and for individual messages.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Debug/test hook observing each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Turn budget; `send_to_model` decrements it and refuses to run at zero.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
364#[derive(Clone, Debug, PartialEq, Eq)]
365pub enum ThreadSummary {
366 Pending,
367 Generating,
368 Ready(SharedString),
369 Error,
370}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Records that a request exceeded the model's context window; persisted with
/// the thread (see [`Thread::serialize`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates a new, empty thread.
    ///
    /// Completion mode comes from global settings and the model from the global
    /// registry's default; a background task immediately captures the initial
    /// project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project/git state as of thread creation, shared so
                // `serialize` can await it later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized representation.
    ///
    /// Attached contexts and crease handles are not persisted, so restored
    /// messages come back with empty context lists and `context: None` creases.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Crease handles are not persisted.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Installs a debug/test hook that observes every completion request and
    /// the events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
605
    /// Returns the configured model, falling back to (and caching) the global
    /// default model when none is set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The model currently selected for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Selects (or clears) the model used by this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the thread's auto-generated title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the summary, unless one is still pending or being generated.
    /// Empty input falls back to the default title; emits
    /// [`ThreadEvent::SummaryChanged`] only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            // Don't overwrite a summary that hasn't been produced yet.
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The completion mode (e.g. normal vs. burn) used for requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
    /// Looks up a message by ID.
    ///
    /// Uses binary search: `messages` stays sorted by ascending ID because IDs
    /// are assigned monotonically on insert.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream is considered stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived (resets the staleness timer).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
688
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Iterates pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for a message, if one was captured.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project's git state to `checkpoint` and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are tracked in `last_restore_checkpoint`;
    /// [`ThreadEvent::CheckpointChanged`] is emitted when the restore starts
    /// and again when it finishes.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be surfaced via `last_restore_checkpoint`.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpoint's message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Resolves the checkpoint taken before the current generation, once the
    /// thread has stopped generating: it is kept only if the project's git
    /// state actually changed since it was captured (or if the comparison
    /// itself failed). No-op while still generating.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep checkpoints that captured a distinct state.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: conservatively keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
800
    /// Removes the message with `message_id` and every message after it,
    /// dropping their recorded checkpoints. No-op when the ID is absent.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        // Disjoint field borrows: `drain` borrows `messages` while the loop
        // body mutates `checkpoints_by_message`.
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
814
    /// Iterates the contexts attached to the given message (empty when the
    /// message doesn't exist or carries no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    /// Usage reported for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether a tool-use limit has been reached (set during streaming).
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a completed tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(str) => Some(str),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card produced by a tool use, if the tool created one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
895
896 /// Return tools that are both enabled and supported by the model
897 pub fn available_tools(
898 &self,
899 cx: &App,
900 model: Arc<dyn LanguageModel>,
901 ) -> Vec<LanguageModelRequestTool> {
902 if model.supports_tools() {
903 self.tools()
904 .read(cx)
905 .enabled_tools(cx)
906 .into_iter()
907 .filter_map(|tool| {
908 // Skip tools that cannot be supported
909 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
910 Some(LanguageModelRequestTool {
911 name: tool.name(),
912 description: tool.description(),
913 input_schema,
914 })
915 })
916 .collect()
917 } else {
918 Vec::default()
919 }
920 }
921
    /// Appends a user message built from `text`, its loaded context, and any
    /// editor creases, returning the new message's ID.
    ///
    /// Buffers referenced by the context are marked as read in the action log,
    /// and `git_checkpoint` (when provided) becomes the pending checkpoint
    /// tied to this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // Stash the checkpoint; it is finalized when generation completes.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
957
    /// Appends an assistant message with no context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with a freshly-assigned ID, bumps `updated_at`, and
    /// emits [`ThreadEvent::MessageAdded`]. Returns the new ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
992
    /// Overwrites the role and segments of message `id`, optionally replacing
    /// its loaded context and recording a git checkpoint for it.
    ///
    /// Returns `false` (without emitting) when no message has that ID.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1023
1024 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1025 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1026 return false;
1027 };
1028 self.messages.remove(index);
1029 self.touch_updated_at();
1030 cx.emit(ThreadEvent::MessageDeleted(id));
1031 true
1032 }
1033
1034 /// Returns the representation of this [`Thread`] in a textual form.
1035 ///
1036 /// This is the representation we use when attaching a thread as context to another thread.
1037 pub fn text(&self) -> String {
1038 let mut text = String::new();
1039
1040 for message in &self.messages {
1041 text.push_str(match message.role {
1042 language_model::Role::User => "User:",
1043 language_model::Role::Assistant => "Agent:",
1044 language_model::Role::System => "System:",
1045 });
1046 text.push('\n');
1047
1048 for segment in &message.segments {
1049 match segment {
1050 MessageSegment::Text(content) => text.push_str(content),
1051 MessageSegment::Thinking { text: content, .. } => {
1052 text.push_str(&format!("<think>{}</think>", content))
1053 }
1054 MessageSegment::RedactedThinking(_) => {}
1055 }
1056 }
1057 text.push('\n');
1058 }
1059
1060 text
1061 }
1062
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot (captured at thread creation)
    /// before reading the rest of the state.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results live in `tool_use`, keyed by message ID.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1145
    /// Returns how many model turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1149
    /// Replaces the remaining-turn budget that `send_to_model` decrements.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1153
1154 pub fn send_to_model(
1155 &mut self,
1156 model: Arc<dyn LanguageModel>,
1157 window: Option<AnyWindowHandle>,
1158 cx: &mut Context<Self>,
1159 ) {
1160 if self.remaining_turns == 0 {
1161 return;
1162 }
1163
1164 self.remaining_turns -= 1;
1165
1166 let request = self.to_completion_request(model.clone(), cx);
1167
1168 self.stream_completion(request, model, window, cx);
1169 }
1170
1171 pub fn used_tools_since_last_user_message(&self) -> bool {
1172 for message in self.messages.iter().rev() {
1173 if self.tool_use.message_has_tool_results(message.id) {
1174 return true;
1175 } else if message.role == Role::User {
1176 return false;
1177 }
1178 }
1179
1180 false
1181 }
1182
    /// Builds the complete `LanguageModelRequest` for the next turn: system
    /// prompt, conversation history with tool uses and results, stale-file
    /// notices, the available tool set, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. If the context
        // isn't ready yet, or prompt generation fails, surface an error to the
        // UI but continue building the request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the last message that is safe to mark as a
        // prompt-cache anchor (i.e. it has no still-pending tool uses).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are replayed as a tool-use entry on the
            // assistant message plus a follow-up user message carrying the
            // results; pending tool uses are skipped entirely.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Append a trailing notice about files that changed since last read.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1340
1341 fn to_summarize_request(
1342 &self,
1343 model: &Arc<dyn LanguageModel>,
1344 added_user_message: String,
1345 cx: &App,
1346 ) -> LanguageModelRequest {
1347 let mut request = LanguageModelRequest {
1348 thread_id: None,
1349 prompt_id: None,
1350 mode: None,
1351 messages: vec![],
1352 tools: Vec::new(),
1353 tool_choice: None,
1354 stop: Vec::new(),
1355 temperature: AssistantSettings::temperature_for_model(model, cx),
1356 };
1357
1358 for message in &self.messages {
1359 let mut request_message = LanguageModelRequestMessage {
1360 role: message.role,
1361 content: Vec::new(),
1362 cache: false,
1363 };
1364
1365 for segment in &message.segments {
1366 match segment {
1367 MessageSegment::Text(text) => request_message
1368 .content
1369 .push(MessageContent::Text(text.clone())),
1370 MessageSegment::Thinking { .. } => {}
1371 MessageSegment::RedactedThinking(_) => {}
1372 }
1373 }
1374
1375 if request_message.content.is_empty() {
1376 continue;
1377 }
1378
1379 request.messages.push(request_message);
1380 }
1381
1382 request.messages.push(LanguageModelRequestMessage {
1383 role: Role::User,
1384 content: vec![MessageContent::Text(added_user_message)],
1385 cache: false,
1386 });
1387
1388 request
1389 }
1390
1391 fn attached_tracked_files_state(
1392 &self,
1393 messages: &mut Vec<LanguageModelRequestMessage>,
1394 cx: &App,
1395 ) {
1396 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1397
1398 let mut stale_message = String::new();
1399
1400 let action_log = self.action_log.read(cx);
1401
1402 for stale_file in action_log.stale_buffers(cx) {
1403 let Some(file) = stale_file.read(cx).file() else {
1404 continue;
1405 };
1406
1407 if stale_message.is_empty() {
1408 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1409 }
1410
1411 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1412 }
1413
1414 let mut content = Vec::with_capacity(2);
1415
1416 if !stale_message.is_empty() {
1417 content.push(stale_message.into());
1418 }
1419
1420 if !content.is_empty() {
1421 let context_message = LanguageModelRequestMessage {
1422 role: Role::User,
1423 content,
1424 cache: false,
1425 };
1426
1427 messages.push(context_message);
1428 }
1429 }
1430
    /// Starts streaming a completion for `request` from `model` and registers
    /// it as a `PendingCompletion`.
    ///
    /// The spawned task consumes the model's event stream and applies each
    /// event to the thread (text/thinking chunks, tool uses, token usage,
    /// queue-status updates). When the stream terminates it finalizes the
    /// pending checkpoint, either dispatches pending tool uses (on a
    /// `ToolUse` stop) or clears the agent location, surfaces errors to the
    /// UI, invokes the request callback if one is installed, and records
    /// completion telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture the request/response pair when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Baseline for the per-request token usage reported to telemetry.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message this response is streaming into,
                // created lazily on the first relevant event.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Invalid tool-call JSON is recoverable: record
                                // an error tool result and keep streaming.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage events carry running totals for this
                                // request; only the delta since the previous
                                // update is added to the cumulative count.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial input is forwarded to the UI as it
                                // streams; complete input needs no extra event.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Let other tasks run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types onto dedicated UI states;
                            // everything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1773
1774 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1775 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1776 println!("No thread summary model");
1777 return;
1778 };
1779
1780 if !model.provider.is_authenticated(cx) {
1781 return;
1782 }
1783
1784 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1785 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1786 If the conversation is about a specific subject, include it in the title. \
1787 Be descriptive. DO NOT speak in the first person.";
1788
1789 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1790
1791 self.summary = ThreadSummary::Generating;
1792
1793 self.pending_summary = cx.spawn(async move |this, cx| {
1794 let result = async {
1795 let mut messages = model.model.stream_completion(request, &cx).await?;
1796
1797 let mut new_summary = String::new();
1798 while let Some(event) = messages.next().await {
1799 let Ok(event) = event else {
1800 continue;
1801 };
1802 let text = match event {
1803 LanguageModelCompletionEvent::Text(text) => text,
1804 LanguageModelCompletionEvent::StatusUpdate(
1805 CompletionRequestStatus::UsageUpdated { amount, limit },
1806 ) => {
1807 this.update(cx, |thread, _cx| {
1808 thread.last_usage = Some(RequestUsage {
1809 limit,
1810 amount: amount as i32,
1811 });
1812 })?;
1813 continue;
1814 }
1815 _ => continue,
1816 };
1817
1818 let mut lines = text.lines();
1819 new_summary.extend(lines.next());
1820
1821 // Stop if the LLM generated multiple lines.
1822 if lines.next().is_some() {
1823 break;
1824 }
1825 }
1826
1827 anyhow::Ok(new_summary)
1828 }
1829 .await;
1830
1831 this.update(cx, |this, cx| {
1832 match result {
1833 Ok(new_summary) => {
1834 if new_summary.is_empty() {
1835 this.summary = ThreadSummary::Error;
1836 } else {
1837 this.summary = ThreadSummary::Ready(new_summary.into());
1838 }
1839 }
1840 Err(err) => {
1841 this.summary = ThreadSummary::Error;
1842 log::error!("Failed to generate thread summary: {}", err);
1843 }
1844 }
1845 cx.emit(ThreadEvent::SummaryGenerated);
1846 })
1847 .log_err()?;
1848
1849 Some(())
1850 });
1851 }
1852
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generating (or generated) for the latest message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is saved via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // A summary keyed to the latest message is already in flight or done.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // A failed request resets the state so a later attempt can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1942
    /// Waits until detailed-summary generation settles, returning the
    /// generated summary, or the thread's raw text when no summary exists.
    ///
    /// Returns `None` when the thread entity or the watch channel is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Keep waiting while a generation task is running.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1960
1961 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1962 self.detailed_summary_rx
1963 .borrow()
1964 .text()
1965 .unwrap_or_else(|| self.text().into())
1966 }
1967
1968 pub fn is_generating_detailed_summary(&self) -> bool {
1969 matches!(
1970 &*self.detailed_summary_rx.borrow(),
1971 DetailedSummaryState::Generating { .. }
1972 )
1973 }
1974
    /// Dispatches every idle pending tool use from the last completion.
    ///
    /// Tools that require confirmation (and aren't globally auto-allowed) are
    /// queued for user approval and `ToolConfirmationNeeded` is emitted;
    /// runnable tools are executed immediately; tool names that match no
    /// known tool are answered with an error result listing available tools.
    ///
    /// Returns the snapshot of pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // The same request is shared (via Arc) with every tool invocation.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2028
2029 pub fn handle_hallucinated_tool_use(
2030 &mut self,
2031 tool_use_id: LanguageModelToolUseId,
2032 hallucinated_tool_name: Arc<str>,
2033 window: Option<AnyWindowHandle>,
2034 cx: &mut Context<Thread>,
2035 ) {
2036 let available_tools = self.tools.read(cx).enabled_tools(cx);
2037
2038 let tool_list = available_tools
2039 .iter()
2040 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2041 .collect::<Vec<_>>()
2042 .join("\n");
2043
2044 let error_message = format!(
2045 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2046 hallucinated_tool_name, tool_list
2047 );
2048
2049 let pending_tool_use = self.tool_use.insert_tool_output(
2050 tool_use_id.clone(),
2051 hallucinated_tool_name,
2052 Err(anyhow!("Missing tool call: {error_message}")),
2053 self.configured_model.as_ref(),
2054 );
2055
2056 cx.emit(ThreadEvent::MissingToolUse {
2057 tool_use_id: tool_use_id.clone(),
2058 ui_text: error_message.into(),
2059 });
2060
2061 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2062 }
2063
2064 pub fn receive_invalid_tool_json(
2065 &mut self,
2066 tool_use_id: LanguageModelToolUseId,
2067 tool_name: Arc<str>,
2068 invalid_json: Arc<str>,
2069 error: String,
2070 window: Option<AnyWindowHandle>,
2071 cx: &mut Context<Thread>,
2072 ) {
2073 log::error!("The model returned invalid input JSON: {invalid_json}");
2074
2075 let pending_tool_use = self.tool_use.insert_tool_output(
2076 tool_use_id.clone(),
2077 tool_name,
2078 Err(anyhow!("Error parsing input JSON: {error}")),
2079 self.configured_model.as_ref(),
2080 );
2081 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2082 pending_tool_use.ui_text.clone()
2083 } else {
2084 log::error!(
2085 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2086 );
2087 format!("Unknown tool {}", tool_use_id).into()
2088 };
2089
2090 cx.emit(ThreadEvent::InvalidToolInput {
2091 tool_use_id: tool_use_id.clone(),
2092 ui_text,
2093 invalid_input_json: invalid_json,
2094 });
2095
2096 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2097 }
2098
2099 pub fn run_tool(
2100 &mut self,
2101 tool_use_id: LanguageModelToolUseId,
2102 ui_text: impl Into<SharedString>,
2103 input: serde_json::Value,
2104 request: Arc<LanguageModelRequest>,
2105 tool: Arc<dyn Tool>,
2106 model: Arc<dyn LanguageModel>,
2107 window: Option<AnyWindowHandle>,
2108 cx: &mut Context<Thread>,
2109 ) {
2110 let task =
2111 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2112 self.tool_use
2113 .run_pending_tool(tool_use_id, ui_text.into(), task);
2114 }
2115
    /// Spawns the actual execution of a tool call and returns the task.
    ///
    /// If the tool is disabled in the working set, a ready error result is
    /// substituted instead of running it. Any UI card produced by the tool is
    /// stored immediately; the tool's output is recorded (and
    /// `tool_finished` invoked) once the async work completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            // Disabled tools never run; report an error result instead.
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2166
2167 fn tool_finished(
2168 &mut self,
2169 tool_use_id: LanguageModelToolUseId,
2170 pending_tool_use: Option<PendingToolUse>,
2171 canceled: bool,
2172 window: Option<AnyWindowHandle>,
2173 cx: &mut Context<Self>,
2174 ) {
2175 if self.all_tools_finished() {
2176 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2177 if !canceled {
2178 self.send_to_model(model.clone(), window, cx);
2179 }
2180 self.auto_capture_telemetry(cx);
2181 }
2182 }
2183
2184 cx.emit(ThreadEvent::ToolFinished {
2185 tool_use_id,
2186 pending_tool_use,
2187 });
2188 }
2189
2190 /// Cancels the last pending completion, if there are any pending.
2191 ///
2192 /// Returns whether a completion was canceled.
2193 pub fn cancel_last_completion(
2194 &mut self,
2195 window: Option<AnyWindowHandle>,
2196 cx: &mut Context<Self>,
2197 ) -> bool {
2198 let mut canceled = self.pending_completions.pop().is_some();
2199
2200 for pending_tool_use in self.tool_use.cancel_pending() {
2201 canceled = true;
2202 self.tool_finished(
2203 pending_tool_use.id.clone(),
2204 Some(pending_tool_use),
2205 true,
2206 window,
2207 cx,
2208 );
2209 }
2210
2211 self.finalize_pending_checkpoint(cx);
2212
2213 if canceled {
2214 cx.emit(ThreadEvent::CompletionCanceled);
2215 }
2216
2217 canceled
2218 }
2219
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. The thread itself holds no
    /// editing state; this only broadcasts the event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2227
    /// Returns the thread-level feedback rating, if one has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2231
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2235
    /// Records a thumbs-up/down rating for a single message and reports it to
    /// telemetry, along with a serialized copy of the thread and a final
    /// project snapshot. Re-reporting the same rating is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Nothing to do if the rating hasn't changed.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that needs thread/project access before moving to
        // the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            // Flush immediately so the rating isn't lost if the app exits.
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2293
    /// Records thread-level feedback.
    ///
    /// If the thread contains an assistant message, the feedback is attributed
    /// to the most recent one via [`Self::report_message_feedback`]; otherwise
    /// it is stored on the thread itself and a telemetry event without message
    /// details is reported.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: record feedback on the whole thread.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2339
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree; they are awaited
        // concurrently below via `join_all`.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect paths of dirty (unsaved) buffers. The snapshot is
            // best-effort: a failed context update is simply ignored.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2377
    /// Builds a [`WorktreeSnapshot`] for a single worktree, including git
    /// metadata (remote URL, HEAD SHA, current branch, working diff) when the
    /// worktree belongs to a local repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree's path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; remote ones yield branch info only.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the three layers: no repo found, entity update failure,
            // and the job's own result. Any failure yields `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2455
2456 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2457 let mut markdown = Vec::new();
2458
2459 let summary = self.summary().or_default();
2460 writeln!(markdown, "# {summary}\n")?;
2461
2462 for message in self.messages() {
2463 writeln!(
2464 markdown,
2465 "## {role}\n",
2466 role = match message.role {
2467 Role::User => "User",
2468 Role::Assistant => "Agent",
2469 Role::System => "System",
2470 }
2471 )?;
2472
2473 if !message.loaded_context.text.is_empty() {
2474 writeln!(markdown, "{}", message.loaded_context.text)?;
2475 }
2476
2477 if !message.loaded_context.images.is_empty() {
2478 writeln!(
2479 markdown,
2480 "\n{} images attached as context.\n",
2481 message.loaded_context.images.len()
2482 )?;
2483 }
2484
2485 for segment in &message.segments {
2486 match segment {
2487 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2488 MessageSegment::Thinking { text, .. } => {
2489 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2490 }
2491 MessageSegment::RedactedThinking(_) => {}
2492 }
2493 }
2494
2495 for tool_use in self.tool_uses_for_message(message.id, cx) {
2496 writeln!(
2497 markdown,
2498 "**Use Tool: {} ({})**",
2499 tool_use.name, tool_use.id
2500 )?;
2501 writeln!(markdown, "```json")?;
2502 writeln!(
2503 markdown,
2504 "{}",
2505 serde_json::to_string_pretty(&tool_use.input)?
2506 )?;
2507 writeln!(markdown, "```")?;
2508 }
2509
2510 for tool_result in self.tool_results_for_message(message.id) {
2511 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2512 if tool_result.is_error {
2513 write!(markdown, " (Error)")?;
2514 }
2515
2516 writeln!(markdown, "**\n")?;
2517 match &tool_result.content {
2518 LanguageModelToolResultContent::Text(str) => {
2519 writeln!(markdown, "{}", str)?;
2520 }
2521 LanguageModelToolResultContent::Image(image) => {
2522 writeln!(markdown, "", image.source)?;
2523 }
2524 }
2525
2526 if let Some(output) = tool_result.output.as_ref() {
2527 writeln!(
2528 markdown,
2529 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2530 serde_json::to_string_pretty(output)?
2531 )?;
2532 }
2533 }
2534 }
2535
2536 Ok(String::from_utf8_lossy(&markdown).to_string())
2537 }
2538
    /// Marks the agent's edits within `buffer_range` of `buffer` as accepted,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2549
    /// Marks all of the agent's edits as accepted, delegating to the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2554
    /// Reverts the agent's edits within `buffer_ranges` of `buffer`,
    /// delegating to the action log. Returns the log's asynchronous result.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2565
    /// Returns the action log tracking the agent's edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2573
2574 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2575 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2576 return;
2577 }
2578
2579 let now = Instant::now();
2580 if let Some(last) = self.last_auto_capture_at {
2581 if now.duration_since(last).as_secs() < 10 {
2582 return;
2583 }
2584 }
2585
2586 self.last_auto_capture_at = Some(now);
2587
2588 let thread_id = self.id().clone();
2589 let github_login = self
2590 .project
2591 .read(cx)
2592 .user_store()
2593 .read(cx)
2594 .current_user()
2595 .map(|user| user.github_login.clone());
2596 let client = self.project.read(cx).client();
2597 let serialize_task = self.serialize(cx);
2598
2599 cx.background_executor()
2600 .spawn(async move {
2601 if let Ok(serialized_thread) = serialize_task.await {
2602 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2603 telemetry::event!(
2604 "Agent Thread Auto-Captured",
2605 thread_id = thread_id.to_string(),
2606 thread_data = thread_data,
2607 auto_capture_reason = "tracked_user",
2608 github_login = github_login
2609 );
2610
2611 client.telemetry().flush_events().await;
2612 }
2613 }
2614 })
2615 .detach();
2616 }
2617
    /// Returns the total token usage accumulated across all requests in this
    /// thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2621
2622 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2623 let Some(model) = self.configured_model.as_ref() else {
2624 return TotalTokenUsage::default();
2625 };
2626
2627 let max = model.model.max_token_count();
2628
2629 let index = self
2630 .messages
2631 .iter()
2632 .position(|msg| msg.id == message_id)
2633 .unwrap_or(0);
2634
2635 if index == 0 {
2636 return TotalTokenUsage { total: 0, max };
2637 }
2638
2639 let token_usage = &self
2640 .request_token_usage
2641 .get(index - 1)
2642 .cloned()
2643 .unwrap_or_default();
2644
2645 TotalTokenUsage {
2646 total: token_usage.total_tokens() as usize,
2647 max,
2648 }
2649 }
2650
2651 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2652 let model = self.configured_model.as_ref()?;
2653
2654 let max = model.model.max_token_count();
2655
2656 if let Some(exceeded_error) = &self.exceeded_window_error {
2657 if model.model.id() == exceeded_error.model_id {
2658 return Some(TotalTokenUsage {
2659 total: exceeded_error.token_count,
2660 max,
2661 });
2662 }
2663 }
2664
2665 let total = self
2666 .token_usage_at_last_message()
2667 .unwrap_or_default()
2668 .total_tokens() as usize;
2669
2670 Some(TotalTokenUsage { total, max })
2671 }
2672
    /// Returns the token usage recorded for the last message.
    ///
    /// Prefers the entry aligned with the last message index; if the usage
    /// vector lags behind the message count, falls back to the most recent
    /// recorded entry.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2679
    /// Records `token_usage` for the last message, resizing the usage vector
    /// so its length matches the message count.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any missing entries with the most recent known usage so indices
        // stay aligned with `self.messages`.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2689
2690 pub fn deny_tool_use(
2691 &mut self,
2692 tool_use_id: LanguageModelToolUseId,
2693 tool_name: Arc<str>,
2694 window: Option<AnyWindowHandle>,
2695 cx: &mut Context<Self>,
2696 ) {
2697 let err = Err(anyhow::anyhow!(
2698 "Permission to run tool action denied by user"
2699 ));
2700
2701 self.tool_use.insert_tool_output(
2702 tool_use_id.clone(),
2703 tool_name,
2704 err,
2705 self.configured_model.as_ref(),
2706 );
2707 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2708 }
2709}
2710
/// Errors surfaced to the user while interacting with a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The user exhausted the model request quota for their plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2723
/// Events emitted by a `Thread` to notify the UI and other listeners of
/// streaming progress, message lifecycle changes, and tool activity.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// A chunk of assistant text was streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input JSON that could not be used as-is.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Listeners should cancel any in-progress message editing.
    CancelEditing,
    /// A pending completion was canceled by the user.
    CompletionCanceled,
}
2766
impl EventEmitter<ThreadEvent> for Thread {}

/// A completion request that is currently in flight.
struct PendingCompletion {
    // Monotonic id used to tell concurrent completions apart.
    id: usize,
    queue_state: QueueState,
    // Owns the streaming task; held so the request keeps running while this
    // struct is alive (NOTE(review): presumably canceled on drop per gpui
    // `Task` semantics — confirm).
    _task: Task<()>,
}
2774
2775#[cfg(test)]
2776mod tests {
2777 use super::*;
2778 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2779 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2780 use assistant_tool::ToolRegistry;
2781 use editor::EditorSettings;
2782 use gpui::TestAppContext;
2783 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2784 use project::{FakeFs, Project};
2785 use prompt_store::PromptBuilder;
2786 use serde_json::json;
2787 use settings::{Settings, SettingsStore};
2788 use std::sync::Arc;
2789 use theme::ThemeSettings;
2790 use util::path;
2791 use workspace::Workspace;
2792
    // Verifies that a user message with an attached file context stores the
    // rendered context on the message and prepends it in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2868
    // Verifies that each message only carries context not already attached to
    // an earlier message, and that `new_context_for_thread` respects both the
    // anchor message id and removals from the context store.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Anchoring at message 2 should re-offer everything attached from
        // message 2 onward, plus the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3024
    // Verifies that messages inserted without any context have empty context
    // text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3102
    // Verifies that a buffer edited after being attached as context causes a
    // trailing "stale files" notification message in the completion request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3189
    // Verifies that `model_parameters` temperature overrides match on model
    // and/or provider, and do not apply when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3283
    // Verifies the summary lifecycle: Pending -> Generating -> Ready, that
    // manual summaries are rejected until Ready, and that an empty manual
    // summary falls back to the default.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3363
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        // After summary generation fails, the user should still be able to set
        // a summary by hand, moving the state from Error to Ready.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3385
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        // After a failed summarization, further messages must not re-trigger
        // summarization automatically; an explicit `summarize` call must.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the retried request.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3436
    /// Shared helper: sends one message, completes the assistant response, and
    /// then ends the auto-triggered summary stream without emitting any text,
    /// leaving the thread in `ThreadSummary::Error`.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The completed exchange kicks off summary generation.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (stream closed with no text streamed)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3466
    /// Streams a canned assistant reply through the fake model and closes the
    /// stream, parking the executor before and after so pending tasks settle.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3473
    /// Registers the global settings and tool state the thread tests rely on.
    /// The `SettingsStore` global is installed first; the subsequent `init`/
    /// `register` calls read from it.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3490
3491 // Helper to create a test project with test files
3492 async fn create_test_project(
3493 cx: &mut TestAppContext,
3494 files: serde_json::Value,
3495 ) -> Entity<Project> {
3496 let fs = FakeFs::new(cx.executor());
3497 fs.insert_tree(path!("/test"), files).await;
3498 Project::test(fs, [path!("/test").as_ref()], cx).await
3499 }
3500
    /// Builds the full test fixture: a workspace window, a `ThreadStore` with a
    /// fresh thread, a `ContextStore`, and a fake language model registered as
    /// both the default and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model so both normal completions and summary
        // requests route to it; the model must exist before registration.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3555
3556 async fn add_file_to_context(
3557 project: &Entity<Project>,
3558 context_store: &Entity<ContextStore>,
3559 path: &str,
3560 cx: &mut TestAppContext,
3561 ) -> Result<Entity<language::Buffer>> {
3562 let buffer_path = project
3563 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3564 .unwrap();
3565
3566 let buffer = project
3567 .update(cx, |project, cx| {
3568 project.open_buffer(buffer_path.clone(), cx)
3569 })
3570 .await
3571 .unwrap();
3572
3573 context_store.update(cx, |context_store, cx| {
3574 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3575 });
3576
3577 Ok(buffer)
3578 }
3579}