1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Globally unique identifier for a [`Thread`].
///
/// Wraps an `Arc<str>`, so clones are cheap (refcount bump, no string copy).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Wraps an `Arc<str>`, so clones are cheap.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Identifier for a [`Message`] within a single thread, assigned from a
/// monotonically increasing counter (see `Thread::next_message_id`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text covered by the crease.
    pub range: Range<usize>,
    /// Display metadata (icon and label) used to re-render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text plus (possibly redacted) thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to restore context UI when re-editing this message
    /// (see [`MessageCrease`]).
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One contiguous piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text.
    Text(String),
    /// Model "thinking" output, with an optional provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content the provider redacted; kept as opaque bytes.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never shown. (The emptiness checks were
    /// previously inverted, so a message made up entirely of empty
    /// segments was reported as displayable by
    /// [`Message::should_display_content`].)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees, captured when the thread is created
/// (see `Thread::initial_project_snapshot`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its root path plus Git metadata, if any.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git repository state; presumably `None` when the worktree is not a
    /// Git repository — confirm against the snapshot producer.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, when one was captured.
    pub diff: Option<String>,
}

/// Associates a Git checkpoint with the message that created it, so edits
/// made from that message onward can be rolled back (see
/// `Thread::restore_checkpoint`).
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// Thumbs-up/down feedback recorded for a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
/// The most recent checkpoint-restore operation, which is either still in
/// flight or finished with an error. A successful restore clears this state
/// entirely (see `Thread::restore_checkpoint`).
pub enum LastRestoreCheckpoint {
    /// Restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed; `error` holds the display string.
    Error {
        message_id: MessageId,
        error: String,
    },
}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// Lifecycle of the thread's detailed (long-form) summary. The `message_id`
/// fields record the message the summary generation was anchored to.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A detailed summary is currently being generated.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread, relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; `0` when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be tuned with the
    /// `ZED_THREAD_WARNING_THRESHOLD` env var. An unset or malformed value
    /// now falls back to the default instead of panicking (previously a
    /// non-numeric value would `unwrap()` and crash debug builds).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// At or above the warning threshold, but still below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
315
/// Where a pending completion currently sits relative to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue information yet.
    Sending,
    /// Request is queued at `position` (semantics defined by the completion backend).
    Queued { position: usize },
    /// The provider has started the completion.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight summary / detailed-summary generation tasks.
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages ordered by ascending ID — `Self::message` relies on this
    /// for its binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken before the latest user message; kept or discarded
    /// once generation finishes (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook registered via [`Self::set_request_callback`].
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// State of the thread's short, user-visible summary.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary generated yet.
    Pending,
    /// A summary request is in flight.
    Generating,
    /// The summary text is available.
    Ready(SharedString),
    /// Summary generation failed; callers fall back to [`Self::DEFAULT`].
    Error,
}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Records that the last completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates a new, empty thread.
    ///
    /// Picks up the globally configured default model and preferred completion
    /// mode, and spawns a background task capturing an initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot in the background; `shared()` lets
                // multiple consumers await the same result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model selection is re-resolved against the registry,
    /// falling back to the default model when absent or unresolvable.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest deserialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded with the thread; fall back to the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the context text survives serialization; live
                    // context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Deserialized creases have no live context handle.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Registers a hook that receives each completion request together with
    /// its streamed events; replaces any previously registered callback.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Replaces `last_prompt_id` with a freshly generated ID.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily initializing it from the global
    /// default when unset.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
625
626 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
627 let current_summary = match &self.summary {
628 ThreadSummary::Pending | ThreadSummary::Generating => return,
629 ThreadSummary::Ready(summary) => summary,
630 ThreadSummary::Error => &ThreadSummary::DEFAULT,
631 };
632
633 let mut new_summary = new_summary.into();
634
635 if new_summary.is_empty() {
636 new_summary = ThreadSummary::DEFAULT;
637 }
638
639 if current_summary != &new_summary {
640 self.summary = ThreadSummary::Ready(new_summary);
641 cx.emit(ThreadEvent::SummaryChanged);
642 }
643 }
644
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Uses binary search: `messages` stays sorted by ascending ID because
    /// IDs come from a monotonically increasing counter on insert.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
665
    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    // NOTE(review): returning `None` when not generating relies on
    // `last_received_chunk_at` being cleared when generation ends — that
    // reset is not visible in this file; confirm.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived, resetting staleness.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
688
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The Git checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`'s Git state.
    ///
    /// Marks the restore as pending (and notifies observers) up front. On
    /// success the thread is truncated back to the checkpoint's message and
    /// the marker is cleared; on failure the error string is retained in
    /// `last_restore_checkpoint`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages from the checkpointed one onward and
                    // clear the pending marker.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the in-flight pending checkpoint is invalidated.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Decides whether the checkpoint taken before the last user message is kept.
    ///
    /// Runs only once generation has fully finished. The pending checkpoint
    /// is stored only when the repository actually changed since it was taken
    /// — or when the comparison itself fails, erring on the side of keeping it.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when the repo state changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint; keep the pending one to be safe.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes `message_id` and every later message, along with their
    /// checkpoints. No-ops when the message is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Context items attached to the message with ID `id` (empty when absent).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the message with ID `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of tool calls issued by the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a completed tool use; `None` for unknown IDs or for
    /// image results, which are not rendered here yet.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
898
899 /// Return tools that are both enabled and supported by the model
900 pub fn available_tools(
901 &self,
902 cx: &App,
903 model: Arc<dyn LanguageModel>,
904 ) -> Vec<LanguageModelRequestTool> {
905 if model.supports_tools() {
906 self.tools()
907 .read(cx)
908 .enabled_tools(cx)
909 .into_iter()
910 .filter_map(|tool| {
911 // Skip tools that cannot be supported
912 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
913 Some(LanguageModelRequestTool {
914 name: tool.name(),
915 description: tool.description(),
916 input_schema,
917 })
918 })
919 .collect()
920 } else {
921 Vec::default()
922 }
923 }
924
    /// Appends a user message built from `text`, its loaded context, and creases.
    ///
    /// Buffers referenced by the context are marked as read in the action
    /// log. When `git_checkpoint` is provided it becomes the pending
    /// checkpoint, to be kept or discarded when generation finishes (see
    /// `finalize_pending_checkpoint`).
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Core insertion shared by all message kinds: assigns the next ID,
    /// appends the message, bumps `updated_at`, and emits
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
995
    /// Rewrites the message with ID `id` in place; returns `false` if absent.
    ///
    /// `loaded_context` and `checkpoint` are applied only when `Some`.
    /// Emits [`ThreadEvent::MessageEdited`] on success.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes the message with ID `id`; returns `false` if absent.
    /// Emits [`ThreadEvent::MessageDeleted`] on success.
    // NOTE(review): unlike `truncate`, this leaves any entry for `id` in
    // `checkpoints_by_message` — confirm that is intentional.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1036
1037 /// Returns the representation of this [`Thread`] in a textual form.
1038 ///
1039 /// This is the representation we use when attaching a thread as context to another thread.
1040 pub fn text(&self) -> String {
1041 let mut text = String::new();
1042
1043 for message in &self.messages {
1044 text.push_str(match message.role {
1045 language_model::Role::User => "User:",
1046 language_model::Role::Assistant => "Agent:",
1047 language_model::Role::System => "System:",
1048 });
1049 text.push('\n');
1050
1051 for segment in &message.segments {
1052 match segment {
1053 MessageSegment::Text(content) => text.push_str(content),
1054 MessageSegment::Thinking { text: content, .. } => {
1055 text.push_str(&format!("<think>{}</think>", content))
1056 }
1057 MessageSegment::RedactedThinking(_) => {}
1058 }
1059 }
1060 text.push('\n');
1061 }
1062
1063 text
1064 }
1065
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the background snapshot before reading thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored alongside the message
                        // that issued them.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the context text is persisted; live context
                        // handles are re-resolved on deserialize.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1148
    /// Returns how many more agentic turns this thread may take before it
    /// stops sending requests to the model (see `send_to_model`).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1152
    /// Resets the turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1156
1157 pub fn send_to_model(
1158 &mut self,
1159 model: Arc<dyn LanguageModel>,
1160 window: Option<AnyWindowHandle>,
1161 cx: &mut Context<Self>,
1162 ) {
1163 if self.remaining_turns == 0 {
1164 return;
1165 }
1166
1167 self.remaining_turns -= 1;
1168
1169 let request = self.to_completion_request(model.clone(), cx);
1170
1171 self.stream_completion(request, model, window, cx);
1172 }
1173
1174 pub fn used_tools_since_last_user_message(&self) -> bool {
1175 for message in self.messages.iter().rev() {
1176 if self.tool_use.message_has_tool_results(message.id) {
1177 return true;
1178 } else if message.role == Role::User {
1179 return false;
1180 }
1181 }
1182
1183 false
1184 }
1185
    /// Builds the `LanguageModelRequest` for the next completion from the
    /// thread's current state.
    ///
    /// The request is assembled in this order: the generated system prompt
    /// (when the project context is ready), one request message per thread
    /// message (context + segments + completed tool uses), a synthetic user
    /// message carrying tool results where needed, and finally any stale-file
    /// notice from `attached_tracked_files_state`. The last fully-cacheable
    /// message is marked for prompt caching.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt advertises the tool names the model may call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    // Prompt generation failed; surface the error but continue
                    // building the request without a system message.
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across turns, so it is
                    // always marked cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the last message that is safe to cache (i.e.
        // has no still-pending tool uses); set on the final such message below.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Loaded context (files, symbols, etc.) precedes the segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results must be delivered in a user-role message that
            // immediately follows the assistant message with the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use means this message's content is still
                    // changing, so it must not be cached.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            // Models without max-mode support always run in normal mode.
            Some(CompletionMode::Normal.into())
        };

        request
    }
1343
1344 fn to_summarize_request(
1345 &self,
1346 model: &Arc<dyn LanguageModel>,
1347 added_user_message: String,
1348 cx: &App,
1349 ) -> LanguageModelRequest {
1350 let mut request = LanguageModelRequest {
1351 thread_id: None,
1352 prompt_id: None,
1353 mode: None,
1354 messages: vec![],
1355 tools: Vec::new(),
1356 tool_choice: None,
1357 stop: Vec::new(),
1358 temperature: AssistantSettings::temperature_for_model(model, cx),
1359 };
1360
1361 for message in &self.messages {
1362 let mut request_message = LanguageModelRequestMessage {
1363 role: message.role,
1364 content: Vec::new(),
1365 cache: false,
1366 };
1367
1368 for segment in &message.segments {
1369 match segment {
1370 MessageSegment::Text(text) => request_message
1371 .content
1372 .push(MessageContent::Text(text.clone())),
1373 MessageSegment::Thinking { .. } => {}
1374 MessageSegment::RedactedThinking(_) => {}
1375 }
1376 }
1377
1378 if request_message.content.is_empty() {
1379 continue;
1380 }
1381
1382 request.messages.push(request_message);
1383 }
1384
1385 request.messages.push(LanguageModelRequestMessage {
1386 role: Role::User,
1387 content: vec![MessageContent::Text(added_user_message)],
1388 cache: false,
1389 });
1390
1391 request
1392 }
1393
1394 fn attached_tracked_files_state(
1395 &self,
1396 messages: &mut Vec<LanguageModelRequestMessage>,
1397 cx: &App,
1398 ) {
1399 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1400
1401 let mut stale_message = String::new();
1402
1403 let action_log = self.action_log.read(cx);
1404
1405 for stale_file in action_log.stale_buffers(cx) {
1406 let Some(file) = stale_file.read(cx).file() else {
1407 continue;
1408 };
1409
1410 if stale_message.is_empty() {
1411 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1412 }
1413
1414 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1415 }
1416
1417 let mut content = Vec::with_capacity(2);
1418
1419 if !stale_message.is_empty() {
1420 content.push(stale_message.into());
1421 }
1422
1423 if !content.is_empty() {
1424 let context_message = LanguageModelRequestMessage {
1425 role: Role::User,
1426 content,
1427 cache: false,
1428 };
1429
1430 messages.push(context_message);
1431 }
1432 }
1433
    /// Streams a completion for `request` from `model`, updating the thread as
    /// events arrive.
    ///
    /// Registers a `PendingCompletion` whose task consumes the event stream:
    /// text/thinking chunks are appended to (or start) an assistant message,
    /// tool uses are recorded, usage and queue-status updates are tracked, and
    /// a final `ThreadEvent::Stopped` is emitted with the stop reason or
    /// error. On `StopReason::ToolUse` the pending tools are run; on refusal
    /// the refused turn is deleted per Anthropic's guidance; on error the
    /// completion is canceled and an error event is shown.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record request/response pairs when a callback is installed
        // (used by tests/telemetry hooks).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the delta can be reported
            // to telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message for this response, created lazily on
                // the first event that needs one.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: record
                            // an error tool result and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only the delta since the previous update is
                                // added to the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message, creating an empty one if needed.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool-use JSON is
                                // still streaming; remember it so a streaming
                                // event can be emitted below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so the UI can process the update before the next event.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Collect everything back to (and including)
                                    // the user message that started the turn.
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to their dedicated UI
                            // treatments; everything else gets a generic
                            // error message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Report only the usage attributable to this request.
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1813
1814 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1815 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1816 println!("No thread summary model");
1817 return;
1818 };
1819
1820 if !model.provider.is_authenticated(cx) {
1821 return;
1822 }
1823
1824 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1825 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1826 If the conversation is about a specific subject, include it in the title. \
1827 Be descriptive. DO NOT speak in the first person.";
1828
1829 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1830
1831 self.summary = ThreadSummary::Generating;
1832
1833 self.pending_summary = cx.spawn(async move |this, cx| {
1834 let result = async {
1835 let mut messages = model.model.stream_completion(request, &cx).await?;
1836
1837 let mut new_summary = String::new();
1838 while let Some(event) = messages.next().await {
1839 let Ok(event) = event else {
1840 continue;
1841 };
1842 let text = match event {
1843 LanguageModelCompletionEvent::Text(text) => text,
1844 LanguageModelCompletionEvent::StatusUpdate(
1845 CompletionRequestStatus::UsageUpdated { amount, limit },
1846 ) => {
1847 this.update(cx, |thread, _cx| {
1848 thread.last_usage = Some(RequestUsage {
1849 limit,
1850 amount: amount as i32,
1851 });
1852 })?;
1853 continue;
1854 }
1855 _ => continue,
1856 };
1857
1858 let mut lines = text.lines();
1859 new_summary.extend(lines.next());
1860
1861 // Stop if the LLM generated multiple lines.
1862 if lines.next().is_some() {
1863 break;
1864 }
1865 }
1866
1867 anyhow::Ok(new_summary)
1868 }
1869 .await;
1870
1871 this.update(cx, |this, cx| {
1872 match result {
1873 Ok(new_summary) => {
1874 if new_summary.is_empty() {
1875 this.summary = ThreadSummary::Error;
1876 } else {
1877 this.summary = ThreadSummary::Ready(new_summary.into());
1878 }
1879 }
1880 Err(err) => {
1881 this.summary = ThreadSummary::Error;
1882 log::error!("Failed to generate thread summary: {}", err);
1883 }
1884 }
1885 cx.emit(ThreadEvent::SummaryGenerated);
1886 })
1887 .log_err()?;
1888
1889 Some(())
1890 });
1891 }
1892
    /// Starts generating a detailed Markdown summary of the conversation,
    /// unless one is already generating/generated for the latest message.
    ///
    /// Publishes progress through the `detailed_summary_tx` watch channel and,
    /// once the summary is generated, persists the thread via `thread_store`
    /// so the summary can be reused later. No-ops when the thread is empty,
    /// no summary model is configured, or the provider is unauthenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed to start; reset the state so a later
                // attempt can try again.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1982
1983 pub async fn wait_for_detailed_summary_or_text(
1984 this: &Entity<Self>,
1985 cx: &mut AsyncApp,
1986 ) -> Option<SharedString> {
1987 let mut detailed_summary_rx = this
1988 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1989 .ok()?;
1990 loop {
1991 match detailed_summary_rx.recv().await? {
1992 DetailedSummaryState::Generating { .. } => {}
1993 DetailedSummaryState::NotGenerated => {
1994 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1995 }
1996 DetailedSummaryState::Generated { text, .. } => return Some(text),
1997 }
1998 }
1999 }
2000
    /// Returns the most recent detailed summary if one exists, otherwise the
    /// thread's full text.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.detailed_summary_rx
            .borrow()
            .text()
            .unwrap_or_else(|| self.text().into())
    }
2007
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2014
2015 pub fn use_pending_tools(
2016 &mut self,
2017 window: Option<AnyWindowHandle>,
2018 cx: &mut Context<Self>,
2019 model: Arc<dyn LanguageModel>,
2020 ) -> Vec<PendingToolUse> {
2021 self.auto_capture_telemetry(cx);
2022 let request = Arc::new(self.to_completion_request(model.clone(), cx));
2023 let pending_tool_uses = self
2024 .tool_use
2025 .pending_tool_uses()
2026 .into_iter()
2027 .filter(|tool_use| tool_use.status.is_idle())
2028 .cloned()
2029 .collect::<Vec<_>>();
2030
2031 for tool_use in pending_tool_uses.iter() {
2032 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2033 if tool.needs_confirmation(&tool_use.input, cx)
2034 && !AssistantSettings::get_global(cx).always_allow_tool_actions
2035 {
2036 self.tool_use.confirm_tool_use(
2037 tool_use.id.clone(),
2038 tool_use.ui_text.clone(),
2039 tool_use.input.clone(),
2040 request.clone(),
2041 tool,
2042 );
2043 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2044 } else {
2045 self.run_tool(
2046 tool_use.id.clone(),
2047 tool_use.ui_text.clone(),
2048 tool_use.input.clone(),
2049 request.clone(),
2050 tool,
2051 model.clone(),
2052 window,
2053 cx,
2054 );
2055 }
2056 } else {
2057 self.handle_hallucinated_tool_use(
2058 tool_use.id.clone(),
2059 tool_use.name.clone(),
2060 window,
2061 cx,
2062 );
2063 }
2064 }
2065
2066 pending_tool_uses
2067 }
2068
2069 pub fn handle_hallucinated_tool_use(
2070 &mut self,
2071 tool_use_id: LanguageModelToolUseId,
2072 hallucinated_tool_name: Arc<str>,
2073 window: Option<AnyWindowHandle>,
2074 cx: &mut Context<Thread>,
2075 ) {
2076 let available_tools = self.tools.read(cx).enabled_tools(cx);
2077
2078 let tool_list = available_tools
2079 .iter()
2080 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2081 .collect::<Vec<_>>()
2082 .join("\n");
2083
2084 let error_message = format!(
2085 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2086 hallucinated_tool_name, tool_list
2087 );
2088
2089 let pending_tool_use = self.tool_use.insert_tool_output(
2090 tool_use_id.clone(),
2091 hallucinated_tool_name,
2092 Err(anyhow!("Missing tool call: {error_message}")),
2093 self.configured_model.as_ref(),
2094 );
2095
2096 cx.emit(ThreadEvent::MissingToolUse {
2097 tool_use_id: tool_use_id.clone(),
2098 ui_text: error_message.into(),
2099 });
2100
2101 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2102 }
2103
2104 pub fn receive_invalid_tool_json(
2105 &mut self,
2106 tool_use_id: LanguageModelToolUseId,
2107 tool_name: Arc<str>,
2108 invalid_json: Arc<str>,
2109 error: String,
2110 window: Option<AnyWindowHandle>,
2111 cx: &mut Context<Thread>,
2112 ) {
2113 log::error!("The model returned invalid input JSON: {invalid_json}");
2114
2115 let pending_tool_use = self.tool_use.insert_tool_output(
2116 tool_use_id.clone(),
2117 tool_name,
2118 Err(anyhow!("Error parsing input JSON: {error}")),
2119 self.configured_model.as_ref(),
2120 );
2121 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2122 pending_tool_use.ui_text.clone()
2123 } else {
2124 log::error!(
2125 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2126 );
2127 format!("Unknown tool {}", tool_use_id).into()
2128 };
2129
2130 cx.emit(ThreadEvent::InvalidToolInput {
2131 tool_use_id: tool_use_id.clone(),
2132 ui_text,
2133 invalid_input_json: invalid_json,
2134 });
2135
2136 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2137 }
2138
    /// Starts executing a tool use and marks it as running in the tool-use
    /// state, with `ui_text` as its display label. The spawned task reports
    /// its output back when it completes (see `spawn_tool_use`).
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2155
2156 fn spawn_tool_use(
2157 &mut self,
2158 tool_use_id: LanguageModelToolUseId,
2159 request: Arc<LanguageModelRequest>,
2160 input: serde_json::Value,
2161 tool: Arc<dyn Tool>,
2162 model: Arc<dyn LanguageModel>,
2163 window: Option<AnyWindowHandle>,
2164 cx: &mut Context<Thread>,
2165 ) -> Task<()> {
2166 let tool_name: Arc<str> = tool.name().into();
2167
2168 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2169 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2170 } else {
2171 tool.run(
2172 input,
2173 request,
2174 self.project.clone(),
2175 self.action_log.clone(),
2176 model,
2177 window,
2178 cx,
2179 )
2180 };
2181
2182 // Store the card separately if it exists
2183 if let Some(card) = tool_result.card.clone() {
2184 self.tool_use
2185 .insert_tool_result_card(tool_use_id.clone(), card);
2186 }
2187
2188 cx.spawn({
2189 async move |thread: WeakEntity<Thread>, cx| {
2190 let output = tool_result.output.await;
2191
2192 thread
2193 .update(cx, |thread, cx| {
2194 let pending_tool_use = thread.tool_use.insert_tool_output(
2195 tool_use_id.clone(),
2196 tool_name,
2197 output,
2198 thread.configured_model.as_ref(),
2199 );
2200 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2201 })
2202 .ok();
2203 }
2204 })
2205 }
2206
2207 fn tool_finished(
2208 &mut self,
2209 tool_use_id: LanguageModelToolUseId,
2210 pending_tool_use: Option<PendingToolUse>,
2211 canceled: bool,
2212 window: Option<AnyWindowHandle>,
2213 cx: &mut Context<Self>,
2214 ) {
2215 if self.all_tools_finished() {
2216 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2217 if !canceled {
2218 self.send_to_model(model.clone(), window, cx);
2219 }
2220 self.auto_capture_telemetry(cx);
2221 }
2222 }
2223
2224 cx.emit(ThreadEvent::ToolFinished {
2225 tool_use_id,
2226 pending_tool_use,
2227 });
2228 }
2229
2230 /// Cancels the last pending completion, if there are any pending.
2231 ///
2232 /// Returns whether a completion was canceled.
2233 pub fn cancel_last_completion(
2234 &mut self,
2235 window: Option<AnyWindowHandle>,
2236 cx: &mut Context<Self>,
2237 ) -> bool {
2238 let mut canceled = self.pending_completions.pop().is_some();
2239
2240 for pending_tool_use in self.tool_use.cancel_pending() {
2241 canceled = true;
2242 self.tool_finished(
2243 pending_tool_use.id.clone(),
2244 Some(pending_tool_use),
2245 true,
2246 window,
2247 cx,
2248 );
2249 }
2250
2251 self.finalize_pending_checkpoint(cx);
2252
2253 if canceled {
2254 cx.emit(ThreadEvent::CompletionCanceled);
2255 }
2256
2257 canceled
2258 }
2259
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Emitting the event is all that's needed here; listeners own the
        // actual edit state.
        cx.emit(ThreadEvent::CancelEditing);
    }
2267
    /// Returns the thread-level feedback, if any has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2271
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2275
    /// Records feedback for a single message and reports it — along with the
    /// serialized thread, enabled tools, and a project snapshot — to
    /// telemetry on a background task.
    ///
    /// Submitting the same rating for the same message again is a no-op.
    /// Returns a task that resolves once the telemetry event is flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything we need on the main thread before spawning.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Fall back to JSON null rather than failing the whole report if
            // the serialized thread can't be converted to a JSON value.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2333
    /// Records feedback for the thread, attributing it to the most recent
    /// assistant message when one exists; otherwise stores it thread-wide
    /// and reports a thread-level telemetry event.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: record the feedback against the
            // thread as a whole instead of a specific message.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Best-effort serialization: null rather than an error.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2379
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Per-worktree snapshots are kicked off on the main thread and awaited
    /// together; the resulting [`ProjectSnapshot`] is timestamped at
    /// completion.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect paths of dirty (unsaved) buffers. Failures to re-enter
            // the app context (e.g. during shutdown) are ignored.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2417
    /// Captures a [`WorktreeSnapshot`] (absolute path plus optional git
    /// state) for a single worktree.
    ///
    /// If the worktree entity can no longer be read, a snapshot with an
    /// empty path and no git state is returned instead of an error. Any
    /// failure while querying git likewise degrades to `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose path space contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result/Task layers; any failure along
            // the way yields `None` rather than an error.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2495
2496 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2497 let mut markdown = Vec::new();
2498
2499 let summary = self.summary().or_default();
2500 writeln!(markdown, "# {summary}\n")?;
2501
2502 for message in self.messages() {
2503 writeln!(
2504 markdown,
2505 "## {role}\n",
2506 role = match message.role {
2507 Role::User => "User",
2508 Role::Assistant => "Agent",
2509 Role::System => "System",
2510 }
2511 )?;
2512
2513 if !message.loaded_context.text.is_empty() {
2514 writeln!(markdown, "{}", message.loaded_context.text)?;
2515 }
2516
2517 if !message.loaded_context.images.is_empty() {
2518 writeln!(
2519 markdown,
2520 "\n{} images attached as context.\n",
2521 message.loaded_context.images.len()
2522 )?;
2523 }
2524
2525 for segment in &message.segments {
2526 match segment {
2527 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2528 MessageSegment::Thinking { text, .. } => {
2529 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2530 }
2531 MessageSegment::RedactedThinking(_) => {}
2532 }
2533 }
2534
2535 for tool_use in self.tool_uses_for_message(message.id, cx) {
2536 writeln!(
2537 markdown,
2538 "**Use Tool: {} ({})**",
2539 tool_use.name, tool_use.id
2540 )?;
2541 writeln!(markdown, "```json")?;
2542 writeln!(
2543 markdown,
2544 "{}",
2545 serde_json::to_string_pretty(&tool_use.input)?
2546 )?;
2547 writeln!(markdown, "```")?;
2548 }
2549
2550 for tool_result in self.tool_results_for_message(message.id) {
2551 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2552 if tool_result.is_error {
2553 write!(markdown, " (Error)")?;
2554 }
2555
2556 writeln!(markdown, "**\n")?;
2557 match &tool_result.content {
2558 LanguageModelToolResultContent::Text(text)
2559 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2560 text,
2561 ..
2562 }) => {
2563 writeln!(markdown, "{text}")?;
2564 }
2565 LanguageModelToolResultContent::Image(image) => {
2566 writeln!(markdown, "", image.source)?;
2567 }
2568 }
2569
2570 if let Some(output) = tool_result.output.as_ref() {
2571 writeln!(
2572 markdown,
2573 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2574 serde_json::to_string_pretty(output)?
2575 )?;
2576 }
2577 }
2578 }
2579
2580 Ok(String::from_utf8_lossy(&markdown).to_string())
2581 }
2582
    /// Accepts ("keeps") the agent's edits within `buffer_range` of `buffer`
    /// by delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2593
    /// Accepts all outstanding agent edits tracked by the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2598
    /// Rejects the agent's edits in the given ranges of `buffer`, returning
    /// the action log's task for the operation.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2609
    /// The log of file edits performed by the agent in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2613
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2617
    /// Opportunistically captures the serialized thread to telemetry, gated
    /// on the thread-auto-capture feature flag and throttled to at most once
    /// every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Throttle: skip if the last capture was less than 10 seconds ago.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                // Best-effort: serialization failures are silently dropped.
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
2661
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2665
2666 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2667 let Some(model) = self.configured_model.as_ref() else {
2668 return TotalTokenUsage::default();
2669 };
2670
2671 let max = model.model.max_token_count();
2672
2673 let index = self
2674 .messages
2675 .iter()
2676 .position(|msg| msg.id == message_id)
2677 .unwrap_or(0);
2678
2679 if index == 0 {
2680 return TotalTokenUsage { total: 0, max };
2681 }
2682
2683 let token_usage = &self
2684 .request_token_usage
2685 .get(index - 1)
2686 .cloned()
2687 .unwrap_or_default();
2688
2689 TotalTokenUsage {
2690 total: token_usage.total_tokens() as usize,
2691 max,
2692 }
2693 }
2694
2695 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2696 let model = self.configured_model.as_ref()?;
2697
2698 let max = model.model.max_token_count();
2699
2700 if let Some(exceeded_error) = &self.exceeded_window_error {
2701 if model.model.id() == exceeded_error.model_id {
2702 return Some(TotalTokenUsage {
2703 total: exceeded_error.token_count,
2704 max,
2705 });
2706 }
2707 }
2708
2709 let total = self
2710 .token_usage_at_last_message()
2711 .unwrap_or_default()
2712 .total_tokens() as usize;
2713
2714 Some(TotalTokenUsage { total, max })
2715 }
2716
2717 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2718 self.request_token_usage
2719 .get(self.messages.len().saturating_sub(1))
2720 .or_else(|| self.request_token_usage.last())
2721 .cloned()
2722 }
2723
2724 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2725 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2726 self.request_token_usage
2727 .resize(self.messages.len(), placeholder);
2728
2729 if let Some(last) = self.request_token_usage.last_mut() {
2730 *last = token_usage;
2731 }
2732 }
2733
2734 pub fn deny_tool_use(
2735 &mut self,
2736 tool_use_id: LanguageModelToolUseId,
2737 tool_name: Arc<str>,
2738 window: Option<AnyWindowHandle>,
2739 cx: &mut Context<Self>,
2740 ) {
2741 let err = Err(anyhow::anyhow!(
2742 "Permission to run tool action denied by user"
2743 ));
2744
2745 self.tool_use.insert_tool_output(
2746 tool_use_id.clone(),
2747 tool_name,
2748 err,
2749 self.configured_model.as_ref(),
2750 );
2751 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2752 }
2753}
2754
/// User-facing errors surfaced while running a thread against a model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider demands payment before serving further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has hit its model-request quota.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, displayed with a short header and a message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2767
/// Events emitted by a `Thread` as it streams completions, runs tools, and
/// mutates its message list. Consumed by UI listeners via `EventEmitter`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Surface an error to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool call completed (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    /// Listeners (like ActiveThread) should cancel in-progress edits.
    CancelEditing,
    CompletionCanceled,
}
2810
// `Thread` broadcasts `ThreadEvent`s to observers through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2812
/// A completion request that is currently in flight.
struct PendingCompletion {
    /// Identifier distinguishing successive completions.
    id: usize,
    queue_state: QueueState,
    /// Held so the completion's task keeps running; dropping this entry
    /// drops the task.
    _task: Task<()>,
}
2818
2819#[cfg(test)]
2820mod tests {
2821 use super::*;
2822 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2823 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2824 use assistant_tool::ToolRegistry;
2825 use editor::EditorSettings;
2826 use gpui::TestAppContext;
2827 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2828 use project::{FakeFs, Project};
2829 use prompt_store::PromptBuilder;
2830 use serde_json::json;
2831 use settings::{Settings, SettingsStore};
2832 use std::sync::Arc;
2833 use theme::ThemeSettings;
2834 use util::path;
2835 use workspace::Workspace;
2836
    /// Verifies that file context attached to a user message is rendered into
    /// the message's loaded context and included verbatim, ahead of the
    /// user's text, in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages total: one precedes the user's message in the request.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2913
    /// Verifies that `new_context_for_thread` excludes contexts already
    /// attached to earlier messages, so each message carries only the context
    /// that is new since the previous one — and that passing an earlier
    /// message id re-includes contexts from that point onward.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Passing message2_id should re-include contexts attached at and
        // after message 2 (file2, file3) plus the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3069
    /// Verifies that messages inserted with an empty context load result
    /// carry no context text, both on the message itself and in the
    /// completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3147
    /// Verifies that modifying a buffer that was previously attached as
    /// context causes the next completion request to end with a
    /// "These files changed since last read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3235
    /// Verifies that `model_parameters` overrides in the assistant settings
    /// apply the configured temperature when the provider and/or model match,
    /// and leave it unset when only the model name matches under a different
    /// provider.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3329
    /// Walks a thread's summary through its full lifecycle
    /// (Pending → Generating → Ready) and verifies that manual `set_summary`
    /// calls are ignored until generation has completed.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message and complete the assistant's reply.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream so the
        // summarization task can finish.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks concatenated in stream order)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3409
    /// After a summarization attempt fails, setting a summary manually should
    /// be allowed and move the state to `Ready`.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread's summary into the `ThreadSummary::Error` state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3431
    /// After a summarization failure the summary stays in `Error` even as new
    /// messages arrive; generation is only retried when `summarize` is invoked
    /// explicitly.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread's summary into the `ThreadSummary::Error` state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Complete the retried summary request successfully.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3482
    /// Shared helper: sends one user message, streams a successful assistant
    /// response (which kicks off summary generation), then ends the summary
    /// completion stream without any text so the summary transitions to
    /// `ThreadSummary::Error`.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The assistant reply brings the thread to two messages, which starts
        // summary generation (see test_thread_summary).
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate the summary request ending without producing any text
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3512
    /// Streams a canned assistant reply on the most recent completion request
    /// and closes the stream. The executor is drained before and after so
    /// that pending tasks observe both the request and the response.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3519
    /// Registers the global settings store and every settings consumer these
    /// tests depend on (language, project, assistant, prompt/thread stores,
    /// workspace, language models, theme, editor, tool registry).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is installed first; the init/register calls
            // below presumably read from it — keep this ordering.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3536
3537 // Helper to create a test project with test files
3538 async fn create_test_project(
3539 cx: &mut TestAppContext,
3540 files: serde_json::Value,
3541 ) -> Entity<Project> {
3542 let fs = FakeFs::new(cx.executor());
3543 fs.insert_tree(path!("/test"), files).await;
3544 Project::test(fs, [path!("/test").as_ref()], cx).await
3545 }
3546
3547 async fn setup_test_environment(
3548 cx: &mut TestAppContext,
3549 project: Entity<Project>,
3550 ) -> (
3551 Entity<Workspace>,
3552 Entity<ThreadStore>,
3553 Entity<Thread>,
3554 Entity<ContextStore>,
3555 Arc<dyn LanguageModel>,
3556 ) {
3557 let (workspace, cx) =
3558 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3559
3560 let thread_store = cx
3561 .update(|_, cx| {
3562 ThreadStore::load(
3563 project.clone(),
3564 cx.new(|_| ToolWorkingSet::default()),
3565 None,
3566 Arc::new(PromptBuilder::new(None).unwrap()),
3567 cx,
3568 )
3569 })
3570 .await
3571 .unwrap();
3572
3573 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3574 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3575
3576 let provider = Arc::new(FakeLanguageModelProvider);
3577 let model = provider.test_model();
3578 let model: Arc<dyn LanguageModel> = Arc::new(model);
3579
3580 cx.update(|_, cx| {
3581 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3582 registry.set_default_model(
3583 Some(ConfiguredModel {
3584 provider: provider.clone(),
3585 model: model.clone(),
3586 }),
3587 cx,
3588 );
3589 registry.set_thread_summary_model(
3590 Some(ConfiguredModel {
3591 provider,
3592 model: model.clone(),
3593 }),
3594 cx,
3595 );
3596 })
3597 });
3598
3599 (workspace, thread_store, thread, context_store, model)
3600 }
3601
3602 async fn add_file_to_context(
3603 project: &Entity<Project>,
3604 context_store: &Entity<ContextStore>,
3605 path: &str,
3606 cx: &mut TestAppContext,
3607 ) -> Result<Entity<language::Buffer>> {
3608 let buffer_path = project
3609 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3610 .unwrap();
3611
3612 let buffer = project
3613 .update(cx, |project, cx| {
3614 project.open_buffer(buffer_path.clone(), cx)
3615 })
3616 .await
3617 .unwrap();
3618
3619 context_store.update(cx, |context_store, cx| {
3620 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3621 });
3622
3623 Ok(buffer)
3624 }
3625}