1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Unique identifier for a [`Thread`]; a UUID string behind a cheap-to-clone `Arc`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Stored as a UUID string; see [`PromptId::new`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Monotonically increasing identifier of a message within a single thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message's text covered by the crease.
    pub range: Range<usize>,
    /// Icon and label shown on the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context loaded for this message (see [`LoadedContext`]).
    pub loaded_context: LoadedContext,
    /// Creases used to re-render collapsed context mentions in editors.
    pub creases: Vec<MessageCrease>,
}
119
impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed
    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
    pub fn should_display_content(&self) -> bool {
        // NOTE(review): `all` means one non-displayable segment hides the whole
        // message, while the doc comment says "any" — confirm which predicate
        // is intended here.
        self.segments.iter().all(|segment| segment.should_display())
    }

    /// Appends thinking text, coalescing into the trailing `Thinking` segment
    /// when one exists; a provided `signature` replaces the stored one.
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            // Later chunks may carry the signature; keep the most recent one.
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    /// Appends plain text, coalescing into the trailing `Text` segment when
    /// one exists.
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    /// Flattens the message (context text followed by segments) into a plain
    /// string. Thinking segments are wrapped in `<think>` tags; redacted
    /// thinking is omitted.
    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}
175
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// A reasoning ("thinking") block, optionally carrying a provider
    /// signature for the thinking content.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque provider-encrypted reasoning payload; never rendered.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has content worth rendering.
    ///
    /// Text and thinking segments are displayable only when non-empty
    /// (previously this was inverted, reporting empty segments as
    /// displayable); redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Point-in-time capture of the project's worktrees and unsaved buffers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including git metadata when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message it was taken for, so the
/// project can later be restored to that state (see `Thread::restore_checkpoint`).
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// User feedback on a thread or message (thumbs up / down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Outcome of the most recent checkpoint restore: still running, or failed
/// with an error message. Cleared on success (see `Thread::restore_checkpoint`).
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// Lifecycle of the long-form ("detailed") thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been produced yet.
    #[default]
    NotGenerated,
    /// A summary is being generated for the thread as of `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary is available for the thread as of `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread relative to the model's context-window
/// size (`max`).
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD env var to make the warning easy to
        // exercise. A missing or malformed value falls back to the default
        // instead of panicking (the previous `.unwrap()` would crash on e.g.
        // ZED_THREAD_WARNING_THRESHOLD="abc").
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of context-window consumption.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// State of a pending completion in the send/queue lifecycle.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting behind `position` other requests.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last time the thread was mutated; see `touch_updated_at`.
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight short-summary generation task, if any.
    pending_summary: Task<Option<()>>,
    // In-flight detailed-summary generation task, if any.
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting progress of the detailed summary.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept sorted by ascending `MessageId`; `Thread::message` binary-searches it.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    // Source for the system prompt (see `new`'s `system_prompt` argument).
    project_context: SharedProjectContext,
    // Restorable checkpoints, keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the latest user message; finalized or discarded
    // once generation settles (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot of the project captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    // Set when a request overflowed the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Observer invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining turn budget consumed by `send_to_model`; `u32::MAX` is
    // effectively unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// Lifecycle of the short thread summary used as its title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// Not generated yet.
    Pending,
    /// Generation is underway.
    Generating,
    /// A summary is available.
    Ready(SharedString),
    /// Generation failed.
    Error,
}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Records that the last request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates an empty thread, defaulting to the globally configured model
    /// and completion mode, and kicking off a project snapshot in the
    /// background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Start capturing the snapshot immediately; `Shared` lets later
            // serialization await the same result without re-capturing.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            // Effectively unlimited turns until a caller restricts it.
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model falls back to the registry default when it can no
    /// longer be resolved, and crease contexts come back as `None` (they are
    /// not serialized; see [`MessageCrease::context`]).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // IDs are assigned sequentially, so the next one follows the last
        // stored message (or restarts at 0 for an empty thread).
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // global default if it is unavailable.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
574 pub fn set_request_callback(
575 &mut self,
576 callback: impl 'static
577 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
578 ) {
579 self.request_callback = Some(Box::new(callback));
580 }
581
    /// The stable identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Stamps the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Handle to the shared project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
605
    /// The configured model, initializing it from the global default when no
    /// model has been set on this thread yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Overrides the model used by this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
621
622 pub fn summary(&self) -> &ThreadSummary {
623 &self.summary
624 }
625
626 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
627 let current_summary = match &self.summary {
628 ThreadSummary::Pending | ThreadSummary::Generating => return,
629 ThreadSummary::Ready(summary) => summary,
630 ThreadSummary::Error => &ThreadSummary::DEFAULT,
631 };
632
633 let mut new_summary = new_summary.into();
634
635 if new_summary.is_empty() {
636 new_summary = ThreadSummary::DEFAULT;
637 }
638
639 if current_summary != &new_summary {
640 self.summary = ThreadSummary::Ready(new_summary);
641 cx.emit(ThreadEvent::SummaryChanged);
642 }
643 }
644
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
    /// Looks up a message by ID via binary search.
    ///
    /// Relies on `messages` being sorted by ascending ID, which holds because
    /// IDs are assigned sequentially as messages are appended.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
665
    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds; a chunk gap longer than this counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): this is `None` only when no chunk was ever received,
        // not whenever `is_generating()` is false as the doc comment states —
        // confirm callers gate on `is_generating()`.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
688
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if it is still pending.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses have not yet completed.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`'s git state; on success the thread
    /// is truncated back to the checkpoint's message. Progress and failure are
    /// exposed through `last_restore_checkpoint`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure recorded so it can be surfaced.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Once generation has settled, decides whether the pending checkpoint is
    /// worth keeping: it is kept only if the repository changed since it was
    /// taken (or if that comparison could not be made).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; try again later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A comparison failure defaults to "not equal", i.e. keep the
                // checkpoint rather than silently dropping it.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint failed, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records `checkpoint` so the user can later restore to it.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent restore, while it is pending or after failure.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with `message_id` and every message after it, along
    /// with any checkpoints recorded for the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
814
    /// The contexts attached to the message with `id` (empty when not found).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

    /// Whether the message at index `ix` ends a turn: either it is the final
    /// message while the thread is idle, or it is an assistant message
    /// followed by a user message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        // NOTE(review): `self.message(...)` re-looks-up the message that
        // `get(ix + 1)` already returned; the extra binary search appears
        // redundant.
        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User)
            })
            .unwrap_or(false)
    }
848
    /// Usage reported by the most recent completion, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
866
    /// Tool uses issued by the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results recorded for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a completed tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(str) => Some(str),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The card associated with a tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
895
896 /// Return tools that are both enabled and supported by the model
897 pub fn available_tools(
898 &self,
899 cx: &App,
900 model: Arc<dyn LanguageModel>,
901 ) -> Vec<LanguageModelRequestTool> {
902 if model.supports_tools() {
903 self.tools()
904 .read(cx)
905 .enabled_tools(cx)
906 .into_iter()
907 .filter_map(|tool| {
908 // Skip tools that cannot be supported
909 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
910 Some(LanguageModelRequestTool {
911 name: tool.name(),
912 description: tool.description(),
913 input_schema,
914 })
915 })
916 .collect()
917 } else {
918 Vec::default()
919 }
920 }
921
    /// Appends a user message, marking context-referenced buffers as read in
    /// the action log and stashing `git_checkpoint` as the thread's pending
    /// checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The checkpoint becomes restorable once generation settles; see
        // `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
971
    /// Appends a message with the next sequential ID and announces it via
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
992
993 pub fn edit_message(
994 &mut self,
995 id: MessageId,
996 new_role: Role,
997 new_segments: Vec<MessageSegment>,
998 loaded_context: Option<LoadedContext>,
999 cx: &mut Context<Self>,
1000 ) -> bool {
1001 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1002 return false;
1003 };
1004 message.role = new_role;
1005 message.segments = new_segments;
1006 if let Some(context) = loaded_context {
1007 message.loaded_context = context;
1008 }
1009 self.touch_updated_at();
1010 cx.emit(ThreadEvent::MessageEdited(id));
1011 true
1012 }
1013
1014 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1015 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1016 return false;
1017 };
1018 self.messages.remove(index);
1019 self.touch_updated_at();
1020 cx.emit(ThreadEvent::MessageDeleted(id));
1021 true
1022 }
1023
1024 /// Returns the representation of this [`Thread`] in a textual form.
1025 ///
1026 /// This is the representation we use when attaching a thread as context to another thread.
1027 pub fn text(&self) -> String {
1028 let mut text = String::new();
1029
1030 for message in &self.messages {
1031 text.push_str(match message.role {
1032 language_model::Role::User => "User:",
1033 language_model::Role::Assistant => "Agent:",
1034 language_model::Role::System => "System:",
1035 });
1036 text.push('\n');
1037
1038 for segment in &message.segments {
1039 match segment {
1040 MessageSegment::Text(content) => text.push_str(content),
1041 MessageSegment::Thinking { text: content, .. } => {
1042 text.push_str(&format!("<think>{}</think>", content))
1043 }
1044 MessageSegment::RedactedThinking(_) => {}
1045 }
1046 }
1047 text.push('\n');
1048 }
1049
1050 text
1051 }
1052
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Await the shared snapshot task outside the entity lock.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results live in `tool_use` state, keyed by
                        // message, and are re-attached here.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted; see
                        // `deserialize`, which restores it as such.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1135
    /// Number of model turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1139
    /// Sets the turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1143
1144 pub fn send_to_model(
1145 &mut self,
1146 model: Arc<dyn LanguageModel>,
1147 window: Option<AnyWindowHandle>,
1148 cx: &mut Context<Self>,
1149 ) {
1150 if self.remaining_turns == 0 {
1151 return;
1152 }
1153
1154 self.remaining_turns -= 1;
1155
1156 let request = self.to_completion_request(model.clone(), cx);
1157
1158 self.stream_completion(request, model, window, cx);
1159 }
1160
1161 pub fn used_tools_since_last_user_message(&self) -> bool {
1162 for message in self.messages.iter().rev() {
1163 if self.tool_use.message_has_tool_results(message.id) {
1164 return true;
1165 } else if message.role == Role::User {
1166 return false;
1167 }
1168 }
1169
1170 false
1171 }
1172
    /// Assembles the complete `LanguageModelRequest` for this thread: a
    /// system prompt generated from the project context, every thread
    /// message with its completed tool uses and results interleaved, an
    /// optional trailing message of stale-file context, the available
    /// tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt references the available tools, so collect their
        // names before generating it.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt. Failures are surfaced to the user via
        // a ThreadEvent but do not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that can be marked for prompt
        // caching; messages whose tool uses are still pending are excluded.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // always forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are attached to this message; their results
            // are gathered into a separate User message that follows it.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only models that support it get the configured (max) mode; others
        // always run in normal mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1330
1331 fn to_summarize_request(
1332 &self,
1333 model: &Arc<dyn LanguageModel>,
1334 added_user_message: String,
1335 cx: &App,
1336 ) -> LanguageModelRequest {
1337 let mut request = LanguageModelRequest {
1338 thread_id: None,
1339 prompt_id: None,
1340 mode: None,
1341 messages: vec![],
1342 tools: Vec::new(),
1343 tool_choice: None,
1344 stop: Vec::new(),
1345 temperature: AssistantSettings::temperature_for_model(model, cx),
1346 };
1347
1348 for message in &self.messages {
1349 let mut request_message = LanguageModelRequestMessage {
1350 role: message.role,
1351 content: Vec::new(),
1352 cache: false,
1353 };
1354
1355 for segment in &message.segments {
1356 match segment {
1357 MessageSegment::Text(text) => request_message
1358 .content
1359 .push(MessageContent::Text(text.clone())),
1360 MessageSegment::Thinking { .. } => {}
1361 MessageSegment::RedactedThinking(_) => {}
1362 }
1363 }
1364
1365 if request_message.content.is_empty() {
1366 continue;
1367 }
1368
1369 request.messages.push(request_message);
1370 }
1371
1372 request.messages.push(LanguageModelRequestMessage {
1373 role: Role::User,
1374 content: vec![MessageContent::Text(added_user_message)],
1375 cache: false,
1376 });
1377
1378 request
1379 }
1380
1381 fn attached_tracked_files_state(
1382 &self,
1383 messages: &mut Vec<LanguageModelRequestMessage>,
1384 cx: &App,
1385 ) {
1386 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1387
1388 let mut stale_message = String::new();
1389
1390 let action_log = self.action_log.read(cx);
1391
1392 for stale_file in action_log.stale_buffers(cx) {
1393 let Some(file) = stale_file.read(cx).file() else {
1394 continue;
1395 };
1396
1397 if stale_message.is_empty() {
1398 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1399 }
1400
1401 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1402 }
1403
1404 let mut content = Vec::with_capacity(2);
1405
1406 if !stale_message.is_empty() {
1407 content.push(stale_message.into());
1408 }
1409
1410 if !content.is_empty() {
1411 let context_message = LanguageModelRequestMessage {
1412 role: Role::User,
1413 content,
1414 cache: false,
1415 };
1416
1417 messages.push(context_message);
1418 }
1419 }
1420
    /// Streams a completion from `model` for `request` in a spawned task,
    /// applying each event to the thread as it arrives (text/thinking
    /// chunks, tool uses, token usage, queue status), then handles the final
    /// stop reason or error. The task is registered in
    /// `self.pending_completions` so it can be canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed (e.g. for tests/debugging),
        // capture the outgoing request and collect the response events.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Token usage before this request, used to compute the delta for
            // the completion telemetry event at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Assistant message created for this response, if any yet.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Unparseable tool input is handled inline (an error
                        // result is recorded) instead of failing the stream;
                        // any other completion error aborts it.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative for this
                                // request, so only the delta since the last
                                // update is added to the thread total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message, creating one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool call is
                                // still streaming; forward it to the UI.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield to the executor between events so the UI stays
                    // responsive during long streams.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map well-known error types to dedicated UI
                            // treatments; anything else becomes a generic
                            // error message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token usage attributable to this request
                    // (delta against the snapshot taken before streaming).
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1800
1801 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1802 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1803 println!("No thread summary model");
1804 return;
1805 };
1806
1807 if !model.provider.is_authenticated(cx) {
1808 return;
1809 }
1810
1811 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1812 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1813 If the conversation is about a specific subject, include it in the title. \
1814 Be descriptive. DO NOT speak in the first person.";
1815
1816 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1817
1818 self.summary = ThreadSummary::Generating;
1819
1820 self.pending_summary = cx.spawn(async move |this, cx| {
1821 let result = async {
1822 let mut messages = model.model.stream_completion(request, &cx).await?;
1823
1824 let mut new_summary = String::new();
1825 while let Some(event) = messages.next().await {
1826 let Ok(event) = event else {
1827 continue;
1828 };
1829 let text = match event {
1830 LanguageModelCompletionEvent::Text(text) => text,
1831 LanguageModelCompletionEvent::StatusUpdate(
1832 CompletionRequestStatus::UsageUpdated { amount, limit },
1833 ) => {
1834 this.update(cx, |thread, _cx| {
1835 thread.last_usage = Some(RequestUsage {
1836 limit,
1837 amount: amount as i32,
1838 });
1839 })?;
1840 continue;
1841 }
1842 _ => continue,
1843 };
1844
1845 let mut lines = text.lines();
1846 new_summary.extend(lines.next());
1847
1848 // Stop if the LLM generated multiple lines.
1849 if lines.next().is_some() {
1850 break;
1851 }
1852 }
1853
1854 anyhow::Ok(new_summary)
1855 }
1856 .await;
1857
1858 this.update(cx, |this, cx| {
1859 match result {
1860 Ok(new_summary) => {
1861 if new_summary.is_empty() {
1862 this.summary = ThreadSummary::Error;
1863 } else {
1864 this.summary = ThreadSummary::Ready(new_summary.into());
1865 }
1866 }
1867 Err(err) => {
1868 this.summary = ThreadSummary::Error;
1869 log::error!("Failed to generate thread summary: {}", err);
1870 }
1871 }
1872 cx.emit(ThreadEvent::SummaryGenerated);
1873 })
1874 .log_err()?;
1875
1876 Some(())
1877 });
1878 }
1879
    /// Starts background generation of a detailed Markdown summary of the
    /// thread, unless a summary is already generated (or being generated)
    /// for the latest message. Progress is published through the
    /// `detailed_summary_tx` watch channel; on success the thread is saved
    /// via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all text chunks; individual chunk errors are logged
            // and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1969
    /// Waits for any in-flight detailed-summary generation to settle, then
    /// returns the generated summary — or the thread's plain text when no
    /// summary was generated. Returns `None` if the thread entity is gone
    /// or the state channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1987
1988 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1989 self.detailed_summary_rx
1990 .borrow()
1991 .text()
1992 .unwrap_or_else(|| self.text().into())
1993 }
1994
1995 pub fn is_generating_detailed_summary(&self) -> bool {
1996 matches!(
1997 &*self.detailed_summary_rx.borrow(),
1998 DetailedSummaryState::Generating { .. }
1999 )
2000 }
2001
    /// Starts every idle pending tool use: tools that require confirmation
    /// (and aren't globally allowed) are queued for user approval, known
    /// tools are run immediately, and tool names the model invented are
    /// answered with an error listing the real tools. Returns the pending
    /// tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // All tool runs for this batch share the same snapshot of the
        // completion request.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool that doesn't exist or isn't
                // enabled.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2055
2056 pub fn handle_hallucinated_tool_use(
2057 &mut self,
2058 tool_use_id: LanguageModelToolUseId,
2059 hallucinated_tool_name: Arc<str>,
2060 window: Option<AnyWindowHandle>,
2061 cx: &mut Context<Thread>,
2062 ) {
2063 let available_tools = self.tools.read(cx).enabled_tools(cx);
2064
2065 let tool_list = available_tools
2066 .iter()
2067 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2068 .collect::<Vec<_>>()
2069 .join("\n");
2070
2071 let error_message = format!(
2072 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2073 hallucinated_tool_name, tool_list
2074 );
2075
2076 let pending_tool_use = self.tool_use.insert_tool_output(
2077 tool_use_id.clone(),
2078 hallucinated_tool_name,
2079 Err(anyhow!("Missing tool call: {error_message}")),
2080 self.configured_model.as_ref(),
2081 );
2082
2083 cx.emit(ThreadEvent::MissingToolUse {
2084 tool_use_id: tool_use_id.clone(),
2085 ui_text: error_message.into(),
2086 });
2087
2088 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2089 }
2090
    /// Handles a tool call whose input JSON failed to parse: records the
    /// parse error as the tool's output and notifies the UI so the invalid
    /// input can be shown.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        // Prefer the label of the pending tool use; fall back to a generic
        // label when no pending use was tracked (which itself indicates a
        // bookkeeping problem, hence the error log).
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2125
2126 pub fn run_tool(
2127 &mut self,
2128 tool_use_id: LanguageModelToolUseId,
2129 ui_text: impl Into<SharedString>,
2130 input: serde_json::Value,
2131 request: Arc<LanguageModelRequest>,
2132 tool: Arc<dyn Tool>,
2133 model: Arc<dyn LanguageModel>,
2134 window: Option<AnyWindowHandle>,
2135 cx: &mut Context<Thread>,
2136 ) {
2137 let task =
2138 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2139 self.tool_use
2140 .run_pending_tool(tool_use_id, ui_text.into(), task);
2141 }
2142
    /// Invokes `tool` (unless it has been disabled since the request was
    /// made) and returns a task that records the tool's output on the
    /// thread once it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A tool disabled in the working set is answered with an immediate
        // error rather than being run.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // Record the output and notify `tool_finished`; a dropped
                // thread entity makes this a no-op.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2193
2194 fn tool_finished(
2195 &mut self,
2196 tool_use_id: LanguageModelToolUseId,
2197 pending_tool_use: Option<PendingToolUse>,
2198 canceled: bool,
2199 window: Option<AnyWindowHandle>,
2200 cx: &mut Context<Self>,
2201 ) {
2202 if self.all_tools_finished() {
2203 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2204 if !canceled {
2205 self.send_to_model(model.clone(), window, cx);
2206 }
2207 self.auto_capture_telemetry(cx);
2208 }
2209 }
2210
2211 cx.emit(ThreadEvent::ToolFinished {
2212 tool_use_id,
2213 pending_tool_use,
2214 });
2215 }
2216
2217 /// Cancels the last pending completion, if there are any pending.
2218 ///
2219 /// Returns whether a completion was canceled.
2220 pub fn cancel_last_completion(
2221 &mut self,
2222 window: Option<AnyWindowHandle>,
2223 cx: &mut Context<Self>,
2224 ) -> bool {
2225 let mut canceled = self.pending_completions.pop().is_some();
2226
2227 for pending_tool_use in self.tool_use.cancel_pending() {
2228 canceled = true;
2229 self.tool_finished(
2230 pending_tool_use.id.clone(),
2231 Some(pending_tool_use),
2232 true,
2233 window,
2234 cx,
2235 );
2236 }
2237
2238 self.finalize_pending_checkpoint(cx);
2239
2240 if canceled {
2241 cx.emit(ThreadEvent::CompletionCanceled);
2242 }
2243
2244 canceled
2245 }
2246
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2254
    /// Returns the thread-level feedback rating, if one has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2258
    /// Returns the feedback rating recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2262
    /// Records `feedback` for a specific message and reports it via telemetry.
    ///
    /// Returns an immediately-ready `Ok` if the same rating was already
    /// recorded for this message. Otherwise the returned task snapshots the
    /// project and serialized thread, emits an "Assistant Thread Rated"
    /// telemetry event, and resolves once the event has been flushed.
    ///
    /// NOTE: the bare identifiers passed to `telemetry::event!` (e.g.
    /// `rating`, `thread_id`) become the event's field names.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Skip duplicate reports of the same rating for the same message.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to `null` rather than dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2320
    /// Records thread-level `feedback` and reports it via telemetry.
    ///
    /// If the thread contains at least one assistant message, the feedback is
    /// attached to the most recent one via [`Self::report_message_feedback`];
    /// otherwise it is stored on the thread itself and reported without
    /// message-level details.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to `null` rather than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2366
2367 /// Create a snapshot of the current project state including git information and unsaved buffers.
2368 fn project_snapshot(
2369 project: Entity<Project>,
2370 cx: &mut Context<Self>,
2371 ) -> Task<Arc<ProjectSnapshot>> {
2372 let git_store = project.read(cx).git_store().clone();
2373 let worktree_snapshots: Vec<_> = project
2374 .read(cx)
2375 .visible_worktrees(cx)
2376 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2377 .collect();
2378
2379 cx.spawn(async move |_, cx| {
2380 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2381
2382 let mut unsaved_buffers = Vec::new();
2383 cx.update(|app_cx| {
2384 let buffer_store = project.read(app_cx).buffer_store();
2385 for buffer_handle in buffer_store.read(app_cx).buffers() {
2386 let buffer = buffer_handle.read(app_cx);
2387 if buffer.is_dirty() {
2388 if let Some(file) = buffer.file() {
2389 let path = file.path().to_string_lossy().to_string();
2390 unsaved_buffers.push(path);
2391 }
2392 }
2393 }
2394 })
2395 .ok();
2396
2397 Arc::new(ProjectSnapshot {
2398 worktree_snapshots,
2399 unsaved_buffer_paths: unsaved_buffers,
2400 timestamp: Utc::now(),
2401 })
2402 })
2403 }
2404
    /// Captures a serializable snapshot of a single worktree: its absolute
    /// path plus git metadata (remote URL, HEAD SHA, current branch, and a
    /// HEAD-to-worktree diff) when a local repository covers it.
    ///
    /// Every fallible step degrades gracefully: a dropped entity yields an
    /// empty snapshot, and any git failure yields `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the worktree entity is gone, return an empty snapshot rather
            // than failing the whole project snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the first repository whose root covers this worktree's path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories have a backend we can
                            // query; otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the three fallible layers (entity update, job submission,
            // job result); any failure collapses to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2482
2483 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2484 let mut markdown = Vec::new();
2485
2486 let summary = self.summary().or_default();
2487 writeln!(markdown, "# {summary}\n")?;
2488
2489 for message in self.messages() {
2490 writeln!(
2491 markdown,
2492 "## {role}\n",
2493 role = match message.role {
2494 Role::User => "User",
2495 Role::Assistant => "Agent",
2496 Role::System => "System",
2497 }
2498 )?;
2499
2500 if !message.loaded_context.text.is_empty() {
2501 writeln!(markdown, "{}", message.loaded_context.text)?;
2502 }
2503
2504 if !message.loaded_context.images.is_empty() {
2505 writeln!(
2506 markdown,
2507 "\n{} images attached as context.\n",
2508 message.loaded_context.images.len()
2509 )?;
2510 }
2511
2512 for segment in &message.segments {
2513 match segment {
2514 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2515 MessageSegment::Thinking { text, .. } => {
2516 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2517 }
2518 MessageSegment::RedactedThinking(_) => {}
2519 }
2520 }
2521
2522 for tool_use in self.tool_uses_for_message(message.id, cx) {
2523 writeln!(
2524 markdown,
2525 "**Use Tool: {} ({})**",
2526 tool_use.name, tool_use.id
2527 )?;
2528 writeln!(markdown, "```json")?;
2529 writeln!(
2530 markdown,
2531 "{}",
2532 serde_json::to_string_pretty(&tool_use.input)?
2533 )?;
2534 writeln!(markdown, "```")?;
2535 }
2536
2537 for tool_result in self.tool_results_for_message(message.id) {
2538 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2539 if tool_result.is_error {
2540 write!(markdown, " (Error)")?;
2541 }
2542
2543 writeln!(markdown, "**\n")?;
2544 match &tool_result.content {
2545 LanguageModelToolResultContent::Text(str) => {
2546 writeln!(markdown, "{}", str)?;
2547 }
2548 LanguageModelToolResultContent::Image(image) => {
2549 writeln!(markdown, "", image.source)?;
2550 }
2551 }
2552
2553 if let Some(output) = tool_result.output.as_ref() {
2554 writeln!(
2555 markdown,
2556 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2557 serde_json::to_string_pretty(output)?
2558 )?;
2559 }
2560 }
2561 }
2562
2563 Ok(String::from_utf8_lossy(&markdown).to_string())
2564 }
2565
    /// Marks agent edits within `buffer_range` of `buffer` as accepted,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2576
    /// Marks every outstanding agent edit as accepted via the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2581
    /// Rejects (reverts) agent edits within the given ranges of `buffer`.
    ///
    /// The returned task resolves once the action log has applied the
    /// rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2592
    /// Returns the log tracking agent-driven buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2596
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2600
    /// Opportunistically captures the serialized thread for telemetry.
    ///
    /// Only runs when the `ThreadAutoCapture` feature flag is enabled, and at
    /// most once every 10 seconds per thread. All failures (serialization,
    /// flushing) are silently ignored.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Throttle: skip if a capture happened within the last 10 seconds.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        // Serialization and upload happen off the main thread; the task is
        // detached because nothing awaits the result.
        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
2644
    /// Returns the token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2648
2649 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2650 let Some(model) = self.configured_model.as_ref() else {
2651 return TotalTokenUsage::default();
2652 };
2653
2654 let max = model.model.max_token_count();
2655
2656 let index = self
2657 .messages
2658 .iter()
2659 .position(|msg| msg.id == message_id)
2660 .unwrap_or(0);
2661
2662 if index == 0 {
2663 return TotalTokenUsage { total: 0, max };
2664 }
2665
2666 let token_usage = &self
2667 .request_token_usage
2668 .get(index - 1)
2669 .cloned()
2670 .unwrap_or_default();
2671
2672 TotalTokenUsage {
2673 total: token_usage.total_tokens() as usize,
2674 max,
2675 }
2676 }
2677
2678 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2679 let model = self.configured_model.as_ref()?;
2680
2681 let max = model.model.max_token_count();
2682
2683 if let Some(exceeded_error) = &self.exceeded_window_error {
2684 if model.model.id() == exceeded_error.model_id {
2685 return Some(TotalTokenUsage {
2686 total: exceeded_error.token_count,
2687 max,
2688 });
2689 }
2690 }
2691
2692 let total = self
2693 .token_usage_at_last_message()
2694 .unwrap_or_default()
2695 .total_tokens() as usize;
2696
2697 Some(TotalTokenUsage { total, max })
2698 }
2699
    /// Returns the token usage entry aligned with the last message, falling
    /// back to the most recent recorded entry when the usage list is shorter
    /// than the message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2706
    /// Records `token_usage` for the last message, first padding the usage
    /// list (with the previous value) so it stays aligned one-to-one with
    /// `self.messages`.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Resize so there is exactly one usage slot per message; new slots are
        // filled with the last known usage as a placeholder.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2716
2717 pub fn deny_tool_use(
2718 &mut self,
2719 tool_use_id: LanguageModelToolUseId,
2720 tool_name: Arc<str>,
2721 window: Option<AnyWindowHandle>,
2722 cx: &mut Context<Self>,
2723 ) {
2724 let err = Err(anyhow::anyhow!(
2725 "Permission to run tool action denied by user"
2726 ));
2727
2728 self.tool_use.insert_tool_output(
2729 tool_use_id.clone(),
2730 tool_name,
2731 err,
2732 self.configured_model.as_ref(),
2733 );
2734 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2735 }
2736}
2737
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before further completions.
    #[error("Payment required")]
    PaymentRequired,
    /// The user has reached the model request limit for their plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A general-purpose error with a display header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2750
/// Events emitted by a `Thread` as completions stream, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool invocation streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that is not available.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Tools requested by the model that are ready to run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool invocation completed (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    CancelEditing,
    CompletionCanceled,
}
2793
/// Lets observers subscribe to `ThreadEvent`s emitted by a `Thread`.
impl EventEmitter<ThreadEvent> for Thread {}
2795
/// A completion request that is currently in flight for this thread.
struct PendingCompletion {
    // Identifier distinguishing this completion from others in the thread.
    id: usize,
    // Queueing status of the request (see `QueueState`).
    queue_state: QueueState,
    // Background task driving the completion stream; held so the stream keeps
    // running (NOTE(review): gpui tasks appear to cancel on drop — confirm).
    _task: Task<()>,
}
2801
2802#[cfg(test)]
2803mod tests {
2804 use super::*;
2805 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2806 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2807 use assistant_tool::ToolRegistry;
2808 use editor::EditorSettings;
2809 use gpui::TestAppContext;
2810 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2811 use project::{FakeFs, Project};
2812 use prompt_store::PromptBuilder;
2813 use serde_json::json;
2814 use settings::{Settings, SettingsStore};
2815 use std::sync::Arc;
2816 use theme::ThemeSettings;
2817 use util::path;
2818 use workspace::Workspace;
2819
    /// Verifies that file context attached to a user message is rendered into
    /// both the stored message and the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages: presumably the system prompt at index 0, then our user
        // message with the context text prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2895
    /// Verifies that each user message only carries context not already
    /// attached to an earlier message, and that `new_context_for_thread`
    /// honors an "as of this message" cutoff and context removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, everything attached at or after message 2
        // (file2, file3, and the newly added file4) counts as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3051
    /// Verifies that user messages with no attached context have empty context
    /// text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3129
    /// Verifies that modifying a buffer after it was attached as context makes
    /// the next completion request end with a "files changed" notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 (just inside the first line)
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3216
    /// Verifies that the `model_parameters` temperature setting matches
    /// requests by provider and/or model id, and does not apply when the
    /// provider differs even if the model id matches.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // The provider mismatch means the parameter must not apply.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3310
    /// End-to-end check of the thread-summary lifecycle:
    /// Pending -> Generating -> Ready, including which states accept a
    /// manually supplied summary.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // A fresh thread has no summary: state is Pending and the
        // user-visible text falls back to the default placeholder.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary is rejected while Pending.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message and let the fake model answer it.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Summarization starts automatically once the thread holds
        // >= 2 messages (user + assistant).
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Manual summaries are still rejected while generation is in flight.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // The streamed chunks are concatenated into the final summary.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Once Ready, manual overrides are accepted.
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Setting an empty summary keeps the Ready state but displays the
        // default placeholder text.
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3390
3391 #[gpui::test]
3392 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3393 init_test_settings(cx);
3394
3395 let project = create_test_project(cx, json!({})).await;
3396
3397 let (_, _thread_store, thread, _context_store, model) =
3398 setup_test_environment(cx, project.clone()).await;
3399
3400 test_summarize_error(&model, &thread, cx);
3401
3402 // Now we should be able to set a summary
3403 thread.update(cx, |thread, cx| {
3404 thread.set_summary("Brief Intro", cx);
3405 });
3406
3407 thread.read_with(cx, |thread, _| {
3408 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3409 assert_eq!(thread.summary().or_default(), "Brief Intro");
3410 });
3411 }
3412
    /// A summarization failure is sticky: sending further messages does not
    /// retry it automatically, but an explicit `summarize` call does.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the thread into the summary-error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating.
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually.
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Complete the retried summary request successfully this time.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3463
    /// Drives `thread` into `ThreadSummary::Error`: sends a message,
    /// completes the assistant response (which kicks off summarization),
    /// then ends the summary stream without emitting any text.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The completed assistant response triggers automatic summarization.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Close the summary stream without streaming any text — the thread
        // treats this as a failed summarization attempt.
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and the displayed summary falls back to the
        // default placeholder.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3493
3494 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3495 cx.run_until_parked();
3496 fake_model.stream_last_completion_response("Assistant response".into());
3497 fake_model.end_last_completion_stream();
3498 cx.run_until_parked();
3499 }
3500
    /// Registers the global settings, per-crate settings types, and tool
    /// registry that the thread tests rely on. Must run before any other
    /// test setup.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must exist before any `register`/`init`
            // call below can read or write settings.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3517
3518 // Helper to create a test project with test files
3519 async fn create_test_project(
3520 cx: &mut TestAppContext,
3521 files: serde_json::Value,
3522 ) -> Entity<Project> {
3523 let fs = FakeFs::new(cx.executor());
3524 fs.insert_tree(path!("/test"), files).await;
3525 Project::test(fs, [path!("/test").as_ref()], cx).await
3526 }
3527
3528 async fn setup_test_environment(
3529 cx: &mut TestAppContext,
3530 project: Entity<Project>,
3531 ) -> (
3532 Entity<Workspace>,
3533 Entity<ThreadStore>,
3534 Entity<Thread>,
3535 Entity<ContextStore>,
3536 Arc<dyn LanguageModel>,
3537 ) {
3538 let (workspace, cx) =
3539 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3540
3541 let thread_store = cx
3542 .update(|_, cx| {
3543 ThreadStore::load(
3544 project.clone(),
3545 cx.new(|_| ToolWorkingSet::default()),
3546 None,
3547 Arc::new(PromptBuilder::new(None).unwrap()),
3548 cx,
3549 )
3550 })
3551 .await
3552 .unwrap();
3553
3554 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3555 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3556
3557 let provider = Arc::new(FakeLanguageModelProvider);
3558 let model = provider.test_model();
3559 let model: Arc<dyn LanguageModel> = Arc::new(model);
3560
3561 cx.update(|_, cx| {
3562 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3563 registry.set_default_model(
3564 Some(ConfiguredModel {
3565 provider: provider.clone(),
3566 model: model.clone(),
3567 }),
3568 cx,
3569 );
3570 registry.set_thread_summary_model(
3571 Some(ConfiguredModel {
3572 provider,
3573 model: model.clone(),
3574 }),
3575 cx,
3576 );
3577 })
3578 });
3579
3580 (workspace, thread_store, thread, context_store, model)
3581 }
3582
3583 async fn add_file_to_context(
3584 project: &Entity<Project>,
3585 context_store: &Entity<ContextStore>,
3586 path: &str,
3587 cx: &mut TestAppContext,
3588 ) -> Result<Entity<language::Buffer>> {
3589 let buffer_path = project
3590 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3591 .unwrap();
3592
3593 let buffer = project
3594 .update(cx, |project, cx| {
3595 project.open_buffer(buffer_path.clone(), cx)
3596 })
3597 .await
3598 .unwrap();
3599
3600 context_store.update(cx, |context_store, cx| {
3601 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3602 });
3603
3604 Ok(buffer)
3605 }
3606}