1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon and label to render for the folded crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text / thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    /// Creases for restoring context UI when editing this message.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One chunk of a message: plain text, visible chain-of-thought, or
/// provider-redacted (opaque) thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        /// Provider-issued signature for the thinking block, when available.
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content.
    ///
    /// Fix: the previous implementation returned `text.is_empty()`, which is
    /// inverted — it reported empty segments as displayable and non-empty
    /// ones as hidden, contradicting the contract documented on
    /// `Message::should_display_content`.
    pub fn should_display(&self) -> bool {
        match self {
            // Text and thinking segments are displayable only when non-empty.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque provider data; never display it.
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees, captured for the thread (see
/// `initial_project_snapshot`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata; `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// A git checkpoint associated with the message it was captured for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// Thumbs-up / thumbs-down feedback a user can leave on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Outcome of the most recent checkpoint-restore attempt.
pub enum LastRestoreCheckpoint {
    /// The restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` holds the message to display.
    Error {
        message_id: MessageId,
        error: String,
    },
}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal / warning / exceeded.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` env var. A missing *or malformed*
    /// override now falls back to the default instead of panicking
    /// (previously `.parse().unwrap()` crashed on an unparsable value).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the running total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Queueing status of a pending completion request.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue information yet.
    Sending,
    /// The request is waiting behind `position` other requests.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time this thread was mutated (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    /// Short, user-visible title for the thread.
    summary: ThreadSummary,
    /// In-flight task generating `summary`, if any.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary, if any.
    detailed_summary_task: Task<Option<()>>,
    /// Watch channel broadcasting detailed-summary progress to observers.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages in ascending-id order; `Self::message` binary-searches this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    /// Id of the most recent user prompt submission (see `advance_prompt_id`).
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were captured for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// Completions currently streaming; non-empty implies `is_generating()`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    /// State of all tool invocations, pending and completed.
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore, if one happened.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint for the in-progress turn; resolved by `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot captured at thread creation; shared so `serialize` can await it.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    /// Set when a request failed because the context window was exceeded.
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    /// Whole-thread feedback left by the user, if any.
    feedback: Option<ThreadFeedback>,
    /// Per-message feedback left by the user.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // Presumably throttles `auto_capture_telemetry` — confirm at its definition.
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Hook invoked with each request and the events received for it.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining turns before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
364#[derive(Clone, Debug, PartialEq, Eq)]
365pub enum ThreadSummary {
366 Pending,
367 Generating,
368 Ready(SharedString),
369 Error,
370}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates an empty thread backed by `project`, using the globally
    /// configured default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project's initial state in the background; the
            // `Shared` future lets `serialize` await the same snapshot later.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Rebuilds a thread from its serialized form.
    ///
    /// Per-request state (pending completions, restore status, feedback,
    /// usage limits) is reset; only persisted data is restored.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue id numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the persisted model selection; fall back to the global
        // default when it is absent or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // context handles and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Context handles are not serialized (see
                            // `MessageCrease::context` docs).
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Installs a hook that observes every request and its streamed events.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// The configured model, lazily initialized from the global default the
    /// first time it is requested while unset.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
625
626 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
627 let current_summary = match &self.summary {
628 ThreadSummary::Pending | ThreadSummary::Generating => return,
629 ThreadSummary::Ready(summary) => summary,
630 ThreadSummary::Error => &ThreadSummary::DEFAULT,
631 };
632
633 let mut new_summary = new_summary.into();
634
635 if new_summary.is_empty() {
636 new_summary = ThreadSummary::DEFAULT;
637 }
638
639 if current_summary != &new_summary {
640 self.summary = ThreadSummary::Ready(new_summary);
641 cx.emit(ThreadEvent::SummaryChanged);
642 }
643 }
644
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id.
    ///
    /// Uses binary search: `messages` stays sorted by ascending id because
    /// ids are assigned monotonically on insert.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is streaming or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // NOTE(review): this only inspects `last_received_chunk_at`; the
        // "None when not generating" contract relies on that field being
        // cleared when generation ends — confirm at the call sites.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    // Records that a streamed chunk just arrived, resetting staleness.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue status of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The tool use with `id`, if it is still pending.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on explicit user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpoint's message on success.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so observers can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so the failure reason can be shown.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the restored message and everything after it to
                    // match the reverted project state.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Resolves `pending_checkpoint` once generation has finished, keeping it
    /// only when the repository actually changed during the turn.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; finalize later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint when the turn changed the repo;
                // a comparison failure counts as "changed" (unwrap_or(false)).
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not take a final checkpoint: keep the pending one rather
            // than silently losing it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Outcome of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
800
801 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
802 let Some(message_ix) = self
803 .messages
804 .iter()
805 .rposition(|message| message.id == message_id)
806 else {
807 return;
808 };
809 for deleted_message in self.messages.drain(message_ix..) {
810 self.checkpoints_by_message.remove(&deleted_message.id);
811 }
812 cx.notify();
813 }
814
    /// Iterates over the context items attached to the message with `id`;
    /// empty when the id is unknown.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use, if it produced text.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
898
899 /// Return tools that are both enabled and supported by the model
900 pub fn available_tools(
901 &self,
902 cx: &App,
903 model: Arc<dyn LanguageModel>,
904 ) -> Vec<LanguageModelRequestTool> {
905 if model.supports_tools() {
906 self.tools()
907 .read(cx)
908 .enabled_tools(cx)
909 .into_iter()
910 .filter_map(|tool| {
911 // Skip tools that cannot be supported
912 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
913 Some(LanguageModelRequestTool {
914 name: tool.name(),
915 description: tool.description(),
916 input_schema,
917 })
918 })
919 .collect()
920 } else {
921 Vec::default()
922 }
923 }
924
    /// Appends a user message, marking referenced context buffers as read in
    /// the action log and stashing `git_checkpoint` as the turn's pending
    /// checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Record each context buffer as read in the action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The checkpoint stays pending until the turn completes (see
        // `finalize_pending_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
960
    /// Appends an assistant message with no context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next sequential id, bumps `updated_at`,
    /// and emits `MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
995
    /// Edits the message with `id` in place, optionally replacing its context
    /// and attaching a checkpoint. Returns `false` when the id is unknown.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        // `None` leaves the existing context untouched.
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1026
1027 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1028 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1029 return false;
1030 };
1031 self.messages.remove(index);
1032 self.touch_updated_at();
1033 cx.emit(ThreadEvent::MessageDeleted(id));
1034 true
1035 }
1036
1037 /// Returns the representation of this [`Thread`] in a textual form.
1038 ///
1039 /// This is the representation we use when attaching a thread as context to another thread.
1040 pub fn text(&self) -> String {
1041 let mut text = String::new();
1042
1043 for message in &self.messages {
1044 text.push_str(match message.role {
1045 language_model::Role::User => "User:",
1046 language_model::Role::Assistant => "Agent:",
1047 language_model::Role::System => "System:",
1048 });
1049 text.push('\n');
1050
1051 for segment in &message.segments {
1052 match segment {
1053 MessageSegment::Text(content) => text.push_str(content),
1054 MessageSegment::Thinking { text: content, .. } => {
1055 text.push_str(&format!("<think>{}</think>", content))
1056 }
1057 MessageSegment::RedactedThinking(_) => {}
1058 }
1059 }
1060 text.push('\n');
1061 }
1062
1063 text
1064 }
1065
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the startup snapshot before reading the rest of the
            // thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted;
                        // handles and images are rebuilt on deserialize.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1148
    /// Returns how many more turns this thread may send to the model before it
    /// stops auto-continuing.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1152
    /// Sets the number of turns this thread may still send to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1156
1157 pub fn send_to_model(
1158 &mut self,
1159 model: Arc<dyn LanguageModel>,
1160 window: Option<AnyWindowHandle>,
1161 cx: &mut Context<Self>,
1162 ) {
1163 if self.remaining_turns == 0 {
1164 return;
1165 }
1166
1167 self.remaining_turns -= 1;
1168
1169 let request = self.to_completion_request(model.clone(), cx);
1170
1171 self.stream_completion(request, model, window, cx);
1172 }
1173
1174 pub fn used_tools_since_last_user_message(&self) -> bool {
1175 for message in self.messages.iter().rev() {
1176 if self.tool_use.message_has_tool_results(message.id) {
1177 return true;
1178 } else if message.role == Role::User {
1179 return false;
1180 }
1181 }
1182
1183 false
1184 }
1185
    /// Builds the `LanguageModelRequest` for this thread's full conversation:
    /// system prompt, message history, tool uses paired with their results,
    /// stale-file notices, and the available tool set.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs project context; if it isn't ready we emit a
        // visible error rather than silently sending a promptless request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the newest request message eligible for a prompt-cache
        // anchor; the `cache` flag is applied once, after the loop.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // always forwarded.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses go on this message; their results go in a
            // follow-up User message. A still-pending tool use disqualifies
            // this message from caching.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode only applies when the model supports it; otherwise fall
        // back to normal completion mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1343
1344 fn to_summarize_request(
1345 &self,
1346 model: &Arc<dyn LanguageModel>,
1347 added_user_message: String,
1348 cx: &App,
1349 ) -> LanguageModelRequest {
1350 let mut request = LanguageModelRequest {
1351 thread_id: None,
1352 prompt_id: None,
1353 mode: None,
1354 messages: vec![],
1355 tools: Vec::new(),
1356 tool_choice: None,
1357 stop: Vec::new(),
1358 temperature: AssistantSettings::temperature_for_model(model, cx),
1359 };
1360
1361 for message in &self.messages {
1362 let mut request_message = LanguageModelRequestMessage {
1363 role: message.role,
1364 content: Vec::new(),
1365 cache: false,
1366 };
1367
1368 for segment in &message.segments {
1369 match segment {
1370 MessageSegment::Text(text) => request_message
1371 .content
1372 .push(MessageContent::Text(text.clone())),
1373 MessageSegment::Thinking { .. } => {}
1374 MessageSegment::RedactedThinking(_) => {}
1375 }
1376 }
1377
1378 if request_message.content.is_empty() {
1379 continue;
1380 }
1381
1382 request.messages.push(request_message);
1383 }
1384
1385 request.messages.push(LanguageModelRequestMessage {
1386 role: Role::User,
1387 content: vec![MessageContent::Text(added_user_message)],
1388 cache: false,
1389 });
1390
1391 request
1392 }
1393
1394 fn attached_tracked_files_state(
1395 &self,
1396 messages: &mut Vec<LanguageModelRequestMessage>,
1397 cx: &App,
1398 ) {
1399 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1400
1401 let mut stale_message = String::new();
1402
1403 let action_log = self.action_log.read(cx);
1404
1405 for stale_file in action_log.stale_buffers(cx) {
1406 let Some(file) = stale_file.read(cx).file() else {
1407 continue;
1408 };
1409
1410 if stale_message.is_empty() {
1411 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1412 }
1413
1414 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1415 }
1416
1417 let mut content = Vec::with_capacity(2);
1418
1419 if !stale_message.is_empty() {
1420 content.push(stale_message.into());
1421 }
1422
1423 if !content.is_empty() {
1424 let context_message = LanguageModelRequestMessage {
1425 role: Role::User,
1426 content,
1427 cache: false,
1428 };
1429
1430 messages.push(context_message);
1431 }
1432 }
1433
    /// Streams a completion for `request` from `model`, applying events to the
    /// thread as they arrive.
    ///
    /// Registers a `PendingCompletion` whose task drives the stream. When the
    /// stream ends, follow-up work is dispatched: tools are run on
    /// `StopReason::ToolUse`, refused turns are pruned, errors are surfaced,
    /// and telemetry is reported.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Capture request/response pairs only when a request callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the completion's own usage can
            // be computed as a delta for telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message this stream is writing into, once one exists.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Invalid tool-input JSON is recoverable: record the
                            // failed tool call and keep consuming the stream.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request; fold
                                // only the delta since the previous update into
                                // the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must attach to an assistant message;
                                // create one lazily if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are re-emitted
                                // so the UI can render them incrementally.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks stay responsive.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Walk backwards collecting messages until the
                                    // start of the refused turn: the most recent
                                    // User message that follows an Assistant
                                    // message (or is the first message).
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error shapes onto specific UI affordances;
                            // anything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Report only this completion's usage, not the thread total.
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1813
1814 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1815 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1816 println!("No thread summary model");
1817 return;
1818 };
1819
1820 if !model.provider.is_authenticated(cx) {
1821 return;
1822 }
1823
1824 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1825 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1826 If the conversation is about a specific subject, include it in the title. \
1827 Be descriptive. DO NOT speak in the first person.";
1828
1829 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1830
1831 self.summary = ThreadSummary::Generating;
1832
1833 self.pending_summary = cx.spawn(async move |this, cx| {
1834 let result = async {
1835 let mut messages = model.model.stream_completion(request, &cx).await?;
1836
1837 let mut new_summary = String::new();
1838 while let Some(event) = messages.next().await {
1839 let Ok(event) = event else {
1840 continue;
1841 };
1842 let text = match event {
1843 LanguageModelCompletionEvent::Text(text) => text,
1844 LanguageModelCompletionEvent::StatusUpdate(
1845 CompletionRequestStatus::UsageUpdated { amount, limit },
1846 ) => {
1847 this.update(cx, |thread, _cx| {
1848 thread.last_usage = Some(RequestUsage {
1849 limit,
1850 amount: amount as i32,
1851 });
1852 })?;
1853 continue;
1854 }
1855 _ => continue,
1856 };
1857
1858 let mut lines = text.lines();
1859 new_summary.extend(lines.next());
1860
1861 // Stop if the LLM generated multiple lines.
1862 if lines.next().is_some() {
1863 break;
1864 }
1865 }
1866
1867 anyhow::Ok(new_summary)
1868 }
1869 .await;
1870
1871 this.update(cx, |this, cx| {
1872 match result {
1873 Ok(new_summary) => {
1874 if new_summary.is_empty() {
1875 this.summary = ThreadSummary::Error;
1876 } else {
1877 this.summary = ThreadSummary::Ready(new_summary.into());
1878 }
1879 }
1880 Err(err) => {
1881 this.summary = ThreadSummary::Error;
1882 log::error!("Failed to generate thread summary: {}", err);
1883 }
1884 }
1885 cx.emit(ThreadEvent::SummaryGenerated);
1886 })
1887 .log_err()?;
1888
1889 Some(())
1890 });
1891 }
1892
    /// Starts generating a detailed Markdown summary of the thread in the
    /// background, unless one is already generated (or in progress) for the
    /// current last message.
    ///
    /// The result is broadcast through `detailed_summary_tx`, and the thread is
    /// saved via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed to start; reset state so a retry is possible.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the full summary; chunks that fail to decode are
            // logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1982
1983 pub async fn wait_for_detailed_summary_or_text(
1984 this: &Entity<Self>,
1985 cx: &mut AsyncApp,
1986 ) -> Option<SharedString> {
1987 let mut detailed_summary_rx = this
1988 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1989 .ok()?;
1990 loop {
1991 match detailed_summary_rx.recv().await? {
1992 DetailedSummaryState::Generating { .. } => {}
1993 DetailedSummaryState::NotGenerated => {
1994 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1995 }
1996 DetailedSummaryState::Generated { text, .. } => return Some(text),
1997 }
1998 }
1999 }
2000
2001 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2002 self.detailed_summary_rx
2003 .borrow()
2004 .text()
2005 .unwrap_or_else(|| self.text().into())
2006 }
2007
    /// Whether a detailed summary is currently being generated in the background.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2014
2015 pub fn use_pending_tools(
2016 &mut self,
2017 window: Option<AnyWindowHandle>,
2018 cx: &mut Context<Self>,
2019 model: Arc<dyn LanguageModel>,
2020 ) -> Vec<PendingToolUse> {
2021 self.auto_capture_telemetry(cx);
2022 let request = Arc::new(self.to_completion_request(model.clone(), cx));
2023 let pending_tool_uses = self
2024 .tool_use
2025 .pending_tool_uses()
2026 .into_iter()
2027 .filter(|tool_use| tool_use.status.is_idle())
2028 .cloned()
2029 .collect::<Vec<_>>();
2030
2031 for tool_use in pending_tool_uses.iter() {
2032 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2033 if tool.needs_confirmation(&tool_use.input, cx)
2034 && !AssistantSettings::get_global(cx).always_allow_tool_actions
2035 {
2036 self.tool_use.confirm_tool_use(
2037 tool_use.id.clone(),
2038 tool_use.ui_text.clone(),
2039 tool_use.input.clone(),
2040 request.clone(),
2041 tool,
2042 );
2043 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2044 } else {
2045 self.run_tool(
2046 tool_use.id.clone(),
2047 tool_use.ui_text.clone(),
2048 tool_use.input.clone(),
2049 request.clone(),
2050 tool,
2051 model.clone(),
2052 window,
2053 cx,
2054 );
2055 }
2056 } else {
2057 self.handle_hallucinated_tool_use(
2058 tool_use.id.clone(),
2059 tool_use.name.clone(),
2060 window,
2061 cx,
2062 );
2063 }
2064 }
2065
2066 pending_tool_uses
2067 }
2068
2069 pub fn handle_hallucinated_tool_use(
2070 &mut self,
2071 tool_use_id: LanguageModelToolUseId,
2072 hallucinated_tool_name: Arc<str>,
2073 window: Option<AnyWindowHandle>,
2074 cx: &mut Context<Thread>,
2075 ) {
2076 let available_tools = self.tools.read(cx).enabled_tools(cx);
2077
2078 let tool_list = available_tools
2079 .iter()
2080 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2081 .collect::<Vec<_>>()
2082 .join("\n");
2083
2084 let error_message = format!(
2085 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2086 hallucinated_tool_name, tool_list
2087 );
2088
2089 let pending_tool_use = self.tool_use.insert_tool_output(
2090 tool_use_id.clone(),
2091 hallucinated_tool_name,
2092 Err(anyhow!("Missing tool call: {error_message}")),
2093 self.configured_model.as_ref(),
2094 );
2095
2096 cx.emit(ThreadEvent::MissingToolUse {
2097 tool_use_id: tool_use_id.clone(),
2098 ui_text: error_message.into(),
2099 });
2100
2101 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2102 }
2103
2104 pub fn receive_invalid_tool_json(
2105 &mut self,
2106 tool_use_id: LanguageModelToolUseId,
2107 tool_name: Arc<str>,
2108 invalid_json: Arc<str>,
2109 error: String,
2110 window: Option<AnyWindowHandle>,
2111 cx: &mut Context<Thread>,
2112 ) {
2113 log::error!("The model returned invalid input JSON: {invalid_json}");
2114
2115 let pending_tool_use = self.tool_use.insert_tool_output(
2116 tool_use_id.clone(),
2117 tool_name,
2118 Err(anyhow!("Error parsing input JSON: {error}")),
2119 self.configured_model.as_ref(),
2120 );
2121 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2122 pending_tool_use.ui_text.clone()
2123 } else {
2124 log::error!(
2125 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2126 );
2127 format!("Unknown tool {}", tool_use_id).into()
2128 };
2129
2130 cx.emit(ThreadEvent::InvalidToolInput {
2131 tool_use_id: tool_use_id.clone(),
2132 ui_text,
2133 invalid_input_json: invalid_json,
2134 });
2135
2136 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2137 }
2138
2139 pub fn run_tool(
2140 &mut self,
2141 tool_use_id: LanguageModelToolUseId,
2142 ui_text: impl Into<SharedString>,
2143 input: serde_json::Value,
2144 request: Arc<LanguageModelRequest>,
2145 tool: Arc<dyn Tool>,
2146 model: Arc<dyn LanguageModel>,
2147 window: Option<AnyWindowHandle>,
2148 cx: &mut Context<Thread>,
2149 ) {
2150 let task =
2151 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2152 self.tool_use
2153 .run_pending_tool(tool_use_id, ui_text.into(), task);
2154 }
2155
2156 fn spawn_tool_use(
2157 &mut self,
2158 tool_use_id: LanguageModelToolUseId,
2159 request: Arc<LanguageModelRequest>,
2160 input: serde_json::Value,
2161 tool: Arc<dyn Tool>,
2162 model: Arc<dyn LanguageModel>,
2163 window: Option<AnyWindowHandle>,
2164 cx: &mut Context<Thread>,
2165 ) -> Task<()> {
2166 let tool_name: Arc<str> = tool.name().into();
2167
2168 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2169 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2170 } else {
2171 tool.run(
2172 input,
2173 request,
2174 self.project.clone(),
2175 self.action_log.clone(),
2176 model,
2177 window,
2178 cx,
2179 )
2180 };
2181
2182 // Store the card separately if it exists
2183 if let Some(card) = tool_result.card.clone() {
2184 self.tool_use
2185 .insert_tool_result_card(tool_use_id.clone(), card);
2186 }
2187
2188 cx.spawn({
2189 async move |thread: WeakEntity<Thread>, cx| {
2190 let output = tool_result.output.await;
2191
2192 thread
2193 .update(cx, |thread, cx| {
2194 let pending_tool_use = thread.tool_use.insert_tool_output(
2195 tool_use_id.clone(),
2196 tool_name,
2197 output,
2198 thread.configured_model.as_ref(),
2199 );
2200 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2201 })
2202 .ok();
2203 }
2204 })
2205 }
2206
2207 fn tool_finished(
2208 &mut self,
2209 tool_use_id: LanguageModelToolUseId,
2210 pending_tool_use: Option<PendingToolUse>,
2211 canceled: bool,
2212 window: Option<AnyWindowHandle>,
2213 cx: &mut Context<Self>,
2214 ) {
2215 if self.all_tools_finished() {
2216 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2217 if !canceled {
2218 self.send_to_model(model.clone(), window, cx);
2219 }
2220 self.auto_capture_telemetry(cx);
2221 }
2222 }
2223
2224 cx.emit(ThreadEvent::ToolFinished {
2225 tool_use_id,
2226 pending_tool_use,
2227 });
2228 }
2229
2230 /// Cancels the last pending completion, if there are any pending.
2231 ///
2232 /// Returns whether a completion was canceled.
2233 pub fn cancel_last_completion(
2234 &mut self,
2235 window: Option<AnyWindowHandle>,
2236 cx: &mut Context<Self>,
2237 ) -> bool {
2238 let mut canceled = self.pending_completions.pop().is_some();
2239
2240 for pending_tool_use in self.tool_use.cancel_pending() {
2241 canceled = true;
2242 self.tool_finished(
2243 pending_tool_use.id.clone(),
2244 Some(pending_tool_use),
2245 true,
2246 window,
2247 cx,
2248 );
2249 }
2250
2251 self.finalize_pending_checkpoint(cx);
2252
2253 if canceled {
2254 cx.emit(ThreadEvent::CompletionCanceled);
2255 }
2256
2257 canceled
2258 }
2259
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This only emits [`ThreadEvent::CancelEditing`]; listeners (such as the
    /// active thread view) are responsible for actually aborting their editing
    /// operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2267
    /// Returns the thread-level feedback, if any was recorded without an
    /// assistant message to attach it to.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2275
2276 pub fn report_message_feedback(
2277 &mut self,
2278 message_id: MessageId,
2279 feedback: ThreadFeedback,
2280 cx: &mut Context<Self>,
2281 ) -> Task<Result<()>> {
2282 if self.message_feedback.get(&message_id) == Some(&feedback) {
2283 return Task::ready(Ok(()));
2284 }
2285
2286 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2287 let serialized_thread = self.serialize(cx);
2288 let thread_id = self.id().clone();
2289 let client = self.project.read(cx).client();
2290
2291 let enabled_tool_names: Vec<String> = self
2292 .tools()
2293 .read(cx)
2294 .enabled_tools(cx)
2295 .iter()
2296 .map(|tool| tool.name())
2297 .collect();
2298
2299 self.message_feedback.insert(message_id, feedback);
2300
2301 cx.notify();
2302
2303 let message_content = self
2304 .message(message_id)
2305 .map(|msg| msg.to_string())
2306 .unwrap_or_default();
2307
2308 cx.background_spawn(async move {
2309 let final_project_snapshot = final_project_snapshot.await;
2310 let serialized_thread = serialized_thread.await?;
2311 let thread_data =
2312 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2313
2314 let rating = match feedback {
2315 ThreadFeedback::Positive => "positive",
2316 ThreadFeedback::Negative => "negative",
2317 };
2318 telemetry::event!(
2319 "Assistant Thread Rated",
2320 rating,
2321 thread_id,
2322 enabled_tool_names,
2323 message_id = message_id.0,
2324 message_content,
2325 thread_data,
2326 final_project_snapshot
2327 );
2328 client.telemetry().flush_events().await;
2329
2330 Ok(())
2331 })
2332 }
2333
    /// Records thread-level feedback and reports it to telemetry.
    ///
    /// When the thread contains an assistant message, the feedback is attached
    /// to the most recent one via [`Self::report_message_feedback`]. Otherwise
    /// it is stored on the thread itself and an event without a message id is
    /// reported from the background executor.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Prefer attaching feedback to the latest assistant message.
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather than
                // dropping the telemetry event entirely.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2379
2380 /// Create a snapshot of the current project state including git information and unsaved buffers.
2381 fn project_snapshot(
2382 project: Entity<Project>,
2383 cx: &mut Context<Self>,
2384 ) -> Task<Arc<ProjectSnapshot>> {
2385 let git_store = project.read(cx).git_store().clone();
2386 let worktree_snapshots: Vec<_> = project
2387 .read(cx)
2388 .visible_worktrees(cx)
2389 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2390 .collect();
2391
2392 cx.spawn(async move |_, cx| {
2393 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2394
2395 let mut unsaved_buffers = Vec::new();
2396 cx.update(|app_cx| {
2397 let buffer_store = project.read(app_cx).buffer_store();
2398 for buffer_handle in buffer_store.read(app_cx).buffers() {
2399 let buffer = buffer_handle.read(app_cx);
2400 if buffer.is_dirty() {
2401 if let Some(file) = buffer.file() {
2402 let path = file.path().to_string_lossy().to_string();
2403 unsaved_buffers.push(path);
2404 }
2405 }
2406 }
2407 })
2408 .ok();
2409
2410 Arc::new(ProjectSnapshot {
2411 worktree_snapshots,
2412 unsaved_buffer_paths: unsaved_buffers,
2413 timestamp: Utc::now(),
2414 })
2415 })
2416 }
2417
2418 fn worktree_snapshot(
2419 worktree: Entity<project::Worktree>,
2420 git_store: Entity<GitStore>,
2421 cx: &App,
2422 ) -> Task<WorktreeSnapshot> {
2423 cx.spawn(async move |cx| {
2424 // Get worktree path and snapshot
2425 let worktree_info = cx.update(|app_cx| {
2426 let worktree = worktree.read(app_cx);
2427 let path = worktree.abs_path().to_string_lossy().to_string();
2428 let snapshot = worktree.snapshot();
2429 (path, snapshot)
2430 });
2431
2432 let Ok((worktree_path, _snapshot)) = worktree_info else {
2433 return WorktreeSnapshot {
2434 worktree_path: String::new(),
2435 git_state: None,
2436 };
2437 };
2438
2439 let git_state = git_store
2440 .update(cx, |git_store, cx| {
2441 git_store
2442 .repositories()
2443 .values()
2444 .find(|repo| {
2445 repo.read(cx)
2446 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2447 .is_some()
2448 })
2449 .cloned()
2450 })
2451 .ok()
2452 .flatten()
2453 .map(|repo| {
2454 repo.update(cx, |repo, _| {
2455 let current_branch =
2456 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2457 repo.send_job(None, |state, _| async move {
2458 let RepositoryState::Local { backend, .. } = state else {
2459 return GitState {
2460 remote_url: None,
2461 head_sha: None,
2462 current_branch,
2463 diff: None,
2464 };
2465 };
2466
2467 let remote_url = backend.remote_url("origin");
2468 let head_sha = backend.head_sha().await;
2469 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2470
2471 GitState {
2472 remote_url,
2473 head_sha,
2474 current_branch,
2475 diff,
2476 }
2477 })
2478 })
2479 });
2480
2481 let git_state = match git_state {
2482 Some(git_state) => match git_state.ok() {
2483 Some(git_state) => git_state.await.ok(),
2484 None => None,
2485 },
2486 None => None,
2487 };
2488
2489 WorktreeSnapshot {
2490 worktree_path,
2491 git_state,
2492 }
2493 })
2494 }
2495
2496 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2497 let mut markdown = Vec::new();
2498
2499 let summary = self.summary().or_default();
2500 writeln!(markdown, "# {summary}\n")?;
2501
2502 for message in self.messages() {
2503 writeln!(
2504 markdown,
2505 "## {role}\n",
2506 role = match message.role {
2507 Role::User => "User",
2508 Role::Assistant => "Agent",
2509 Role::System => "System",
2510 }
2511 )?;
2512
2513 if !message.loaded_context.text.is_empty() {
2514 writeln!(markdown, "{}", message.loaded_context.text)?;
2515 }
2516
2517 if !message.loaded_context.images.is_empty() {
2518 writeln!(
2519 markdown,
2520 "\n{} images attached as context.\n",
2521 message.loaded_context.images.len()
2522 )?;
2523 }
2524
2525 for segment in &message.segments {
2526 match segment {
2527 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2528 MessageSegment::Thinking { text, .. } => {
2529 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2530 }
2531 MessageSegment::RedactedThinking(_) => {}
2532 }
2533 }
2534
2535 for tool_use in self.tool_uses_for_message(message.id, cx) {
2536 writeln!(
2537 markdown,
2538 "**Use Tool: {} ({})**",
2539 tool_use.name, tool_use.id
2540 )?;
2541 writeln!(markdown, "```json")?;
2542 writeln!(
2543 markdown,
2544 "{}",
2545 serde_json::to_string_pretty(&tool_use.input)?
2546 )?;
2547 writeln!(markdown, "```")?;
2548 }
2549
2550 for tool_result in self.tool_results_for_message(message.id) {
2551 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2552 if tool_result.is_error {
2553 write!(markdown, " (Error)")?;
2554 }
2555
2556 writeln!(markdown, "**\n")?;
2557 match &tool_result.content {
2558 LanguageModelToolResultContent::Text(text)
2559 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2560 text,
2561 ..
2562 }) => {
2563 writeln!(markdown, "{text}")?;
2564 }
2565 LanguageModelToolResultContent::Image(image) => {
2566 writeln!(markdown, "", image.source)?;
2567 }
2568 }
2569
2570 if let Some(output) = tool_result.output.as_ref() {
2571 writeln!(
2572 markdown,
2573 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2574 serde_json::to_string_pretty(output)?
2575 )?;
2576 }
2577 }
2578 }
2579
2580 Ok(String::from_utf8_lossy(&markdown).to_string())
2581 }
2582
    /// Marks the agent-made edits within `buffer_range` as kept (accepted),
    /// delegating to the thread's [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    /// Marks every agent-made edit tracked by the action log as kept.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    /// Rejects (reverts) the agent-made edits within `buffer_ranges`,
    /// returning the action log's task that resolves once done.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2609
    /// The action log tracking edits proposed by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2617
2618 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2619 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2620 return;
2621 }
2622
2623 let now = Instant::now();
2624 if let Some(last) = self.last_auto_capture_at {
2625 if now.duration_since(last).as_secs() < 10 {
2626 return;
2627 }
2628 }
2629
2630 self.last_auto_capture_at = Some(now);
2631
2632 let thread_id = self.id().clone();
2633 let github_login = self
2634 .project
2635 .read(cx)
2636 .user_store()
2637 .read(cx)
2638 .current_user()
2639 .map(|user| user.github_login.clone());
2640 let client = self.project.read(cx).client();
2641 let serialize_task = self.serialize(cx);
2642
2643 cx.background_executor()
2644 .spawn(async move {
2645 if let Ok(serialized_thread) = serialize_task.await {
2646 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2647 telemetry::event!(
2648 "Agent Thread Auto-Captured",
2649 thread_id = thread_id.to_string(),
2650 thread_data = thread_data,
2651 auto_capture_reason = "tracked_user",
2652 github_login = github_login
2653 );
2654
2655 client.telemetry().flush_events().await;
2656 }
2657 }
2658 })
2659 .detach();
2660 }
2661
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2665
2666 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2667 let Some(model) = self.configured_model.as_ref() else {
2668 return TotalTokenUsage::default();
2669 };
2670
2671 let max = model.model.max_token_count();
2672
2673 let index = self
2674 .messages
2675 .iter()
2676 .position(|msg| msg.id == message_id)
2677 .unwrap_or(0);
2678
2679 if index == 0 {
2680 return TotalTokenUsage { total: 0, max };
2681 }
2682
2683 let token_usage = &self
2684 .request_token_usage
2685 .get(index - 1)
2686 .cloned()
2687 .unwrap_or_default();
2688
2689 TotalTokenUsage {
2690 total: token_usage.total_tokens() as usize,
2691 max,
2692 }
2693 }
2694
2695 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2696 let model = self.configured_model.as_ref()?;
2697
2698 let max = model.model.max_token_count();
2699
2700 if let Some(exceeded_error) = &self.exceeded_window_error {
2701 if model.model.id() == exceeded_error.model_id {
2702 return Some(TotalTokenUsage {
2703 total: exceeded_error.token_count,
2704 max,
2705 });
2706 }
2707 }
2708
2709 let total = self
2710 .token_usage_at_last_message()
2711 .unwrap_or_default()
2712 .total_tokens() as usize;
2713
2714 Some(TotalTokenUsage { total, max })
2715 }
2716
    /// Token usage recorded for the last message, if any.
    ///
    /// If there is no entry at the last message's index (more messages than
    /// usage records), falls back to the most recent usage entry.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2723
    /// Records `token_usage` as the usage for the last message.
    ///
    /// First grows the usage vector to match `self.messages` in length,
    /// padding any missing entries with the most recent known usage so
    /// indices keep lining up with messages, then overwrites the last entry.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2733
2734 pub fn deny_tool_use(
2735 &mut self,
2736 tool_use_id: LanguageModelToolUseId,
2737 tool_name: Arc<str>,
2738 window: Option<AnyWindowHandle>,
2739 cx: &mut Context<Self>,
2740 ) {
2741 let err = Err(anyhow::anyhow!(
2742 "Permission to run tool action denied by user"
2743 ));
2744
2745 self.tool_use.insert_tool_output(
2746 tool_use_id.clone(),
2747 tool_name,
2748 err,
2749 self.configured_model.as_ref(),
2750 );
2751 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2752 }
2753}
2754
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user must pay before further completions can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The model request limit for the user's plan was reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2767
/// Events emitted by a `Thread` as completions stream and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion chunk was streamed.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text was streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text was streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// An expected tool use was missing from the model's response.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input JSON that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, successfully or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses should be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool needs user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Any in-progress message editing should be canceled.
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
}
2810
// Threads broadcast `ThreadEvent`s to observers (e.g. the active thread view).
impl EventEmitter<ThreadEvent> for Thread {}
2812
/// A completion request currently in flight for a thread.
struct PendingCompletion {
    // Identifier for this completion; used by `post_inc`-style bookkeeping
    // elsewhere in the thread.
    id: usize,
    // Status of the request within the completion queue.
    queue_state: QueueState,
    // Task driving the streaming completion; retained so it keeps running
    // while this entry is alive.
    _task: Task<()>,
}
2818
2819#[cfg(test)]
2820mod tests {
2821 use super::*;
2822 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2823 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2824 use assistant_tool::ToolRegistry;
2825 use editor::EditorSettings;
2826 use gpui::TestAppContext;
2827 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2828 use project::{FakeFs, Project};
2829 use prompt_store::PromptBuilder;
2830 use serde_json::json;
2831 use settings::{Settings, SettingsStore};
2832 use std::sync::Arc;
2833 use theme::ThemeSettings;
2834 use util::path;
2835 use workspace::Workspace;
2836
    /// Verifies that file context attached to a user message shows up both on
    /// the stored message and in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Message 0 is the system prompt; the user message follows it.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2912
    /// Verifies that each user message only attaches context not already
    /// included in an earlier message, and that `new_context_for_thread`
    /// honors an explicit "up to message" boundary.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With `Some(message2_id)` as the boundary, only contexts attached up
        // to message 1 count as "seen", so files 2-4 are all new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3068
    /// Verifies that messages inserted without any context have empty context
    /// text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3146
    /// Verifies that editing a buffer that was previously attached as context
    /// appends a "stale files" notification to subsequent completion requests.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3233
    /// Verifies that the `model_parameters` temperature setting matches on
    /// provider and/or model id, and does not apply when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // The provider doesn't match, so the temperature override is ignored.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3327
    /// End-to-end check of the thread-summary lifecycle: the summary starts
    /// `Pending`, transitions to `Generating` once a user/assistant exchange
    /// completes, then to `Ready` when the summary stream finishes. Manual
    /// `set_summary` calls are ignored while `Pending` or `Generating` and
    /// only take effect once the summary is `Ready`.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary back in two chunks to exercise accumulation.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3407
    /// After summary generation fails, the summary can still be set manually,
    /// which moves the state from `Error` to `Ready`.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the `ThreadSummary::Error` state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3429
    /// After a failed summary, sending further messages must NOT auto-retry
    /// the summary request (state stays `Error`), but an explicit
    /// `summarize` call must retry and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the `ThreadSummary::Error` state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Complete the retried summary request successfully.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3480
    /// Shared helper that drives `thread` into the `ThreadSummary::Error`
    /// state: sends a user message, completes the assistant response (which
    /// kicks off a summary request), then ends the summary completion stream
    /// without emitting any text.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The assistant reply brings the thread to >= 2 messages, so a
        // summary request is now in flight.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending without streaming any summary text
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3510
    /// Completes the pending request on `fake_model` with a canned assistant
    /// reply, parking the executor before and after so queued effects settle.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3517
    /// Installs the global test `SettingsStore` and registers every settings
    /// domain / global the thread tests touch. Each test calls this first
    /// because globals are scoped to its own `TestAppContext`.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be set as a global before the
            // `init`/`register` calls below can read or register settings.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3534
3535 // Helper to create a test project with test files
3536 async fn create_test_project(
3537 cx: &mut TestAppContext,
3538 files: serde_json::Value,
3539 ) -> Entity<Project> {
3540 let fs = FakeFs::new(cx.executor());
3541 fs.insert_tree(path!("/test"), files).await;
3542 Project::test(fs, [path!("/test").as_ref()], cx).await
3543 }
3544
    /// Spins up the common fixture for thread tests: a test `Workspace`, a
    /// `ThreadStore` with an empty tool working set, one freshly created
    /// `Thread`, a `ContextStore`, and a fake language model registered as
    /// both the default and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // No serialized threads / prompt overrides: load a pristine store.
        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                // Register the same fake model for normal completions and for
                // summary generation, so tests can drive both streams.
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3599
3600 async fn add_file_to_context(
3601 project: &Entity<Project>,
3602 context_store: &Entity<ContextStore>,
3603 path: &str,
3604 cx: &mut TestAppContext,
3605 ) -> Result<Entity<language::Buffer>> {
3606 let buffer_path = project
3607 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3608 .unwrap();
3609
3610 let buffer = project
3611 .update(cx, |project, cx| {
3612 project.open_buffer(buffer_path.clone(), cx)
3613 })
3614 .await
3615 .unwrap();
3616
3617 context_store.update(cx, |context_store, cx| {
3618 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3619 });
3620
3621 Ok(buffer)
3622 }
3623}