1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    /// NOTE(review): units (byte vs. char offsets) are not shown here — confirm with callers.
    pub range: Range<usize>,
    /// Display metadata (icon, label) used to re-render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    /// Who produced this message (user, assistant, or system).
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context (text/images) that was attached when the message was sent.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain message text.
    Text(String),
    /// Model reasoning text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-encrypted reasoning payload; never shown to the user.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty
    /// (previously the check was inverted, treating empty segments as
    /// displayable); redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// A point-in-time capture of the project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// None when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, when one could be computed.
    pub diff: Option<String>,
}
216
/// Associates a git checkpoint with the message it was taken for, so the
/// project can be rolled back to the state at that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Aggregate token consumption for a thread, relative to the model's window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size of the selected model; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context window.
    ///
    /// The warning threshold defaults to 80% of the window. In debug builds it
    /// can be overridden with the `ZED_THREAD_WARNING_THRESHOLD` env var; a
    /// missing or malformed override falls back to the default instead of
    /// panicking (the previous `.unwrap()` aborted on unparsable values).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket for context-window consumption.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Where a pending completion stands with the provider.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue status has been received yet.
    Sending,
    /// The request is waiting in the provider's queue at `position`.
    Queued { position: usize },
    /// The provider has started streaming the completion.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // Background task generating the short summary; result is delivered via
    // `summary`, not the task itself.
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel pair so observers can follow detailed-summary progress.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Ordered by `MessageId`: ids are assigned monotonically via
    // `next_message_id`, and `message()` binary-searches on this invariant.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured when the latest user message was sent; finalized or
    // discarded once generation settles.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; `is_generation_stale`
    // compares against this.
    last_received_chunk_at: Option<Instant>,
    // Optional observer invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining agentic turns; `u32::MAX` effectively means unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
364#[derive(Clone, Debug, PartialEq, Eq)]
365pub enum ThreadSummary {
366 Pending,
367 Generating,
368 Ready(SharedString),
369 Error,
370}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Recorded when a request overflowed the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates an empty thread backed by `project`, with a freshly generated
    /// id, the globally configured default model, and no messages.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's starting state in the background; the
                // shared task lets multiple consumers await the same snapshot.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, summary state, and the previously
    /// selected model (falling back to the registry default). Deserialized
    /// message creases carry no live context (`context: None`), and a fresh
    /// prompt id is generated.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; otherwise the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Registers an observer invoked with each completion request and the raw
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The stable, globally unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh id for the next user prompt.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// The configured model, falling back to (and caching) the registry's
    /// default model when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for this thread's completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// The current state of the thread's short summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
625
626 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
627 let current_summary = match &self.summary {
628 ThreadSummary::Pending | ThreadSummary::Generating => return,
629 ThreadSummary::Ready(summary) => summary,
630 ThreadSummary::Error => &ThreadSummary::DEFAULT,
631 };
632
633 let mut new_summary = new_summary.into();
634
635 if new_summary.is_empty() {
636 new_summary = ThreadSummary::DEFAULT;
637 }
638
639 if current_summary != &new_summary {
640 self.summary = ThreadSummary::Ready(new_summary);
641 cx.emit(ThreadEvent::SummaryChanged);
642 }
643 }
644
    /// The completion mode used for this thread's requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id.
    ///
    /// Uses binary search: message ids are assigned monotonically, so
    /// `messages` is ordered by id.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// All messages in the thread, oldest first.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
665
    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    // NOTE(review): this only returns `None` while `last_received_chunk_at` is
    // unset; it relies on that field being cleared when generation ends —
    // confirm at the call sites that reset it.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue status of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
688
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use has not completed yet.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight before kicking off the async work.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error visible; the thread itself is unchanged.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages from the checkpointed one onward.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
751 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
752 let pending_checkpoint = if self.is_generating() {
753 return;
754 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
755 checkpoint
756 } else {
757 return;
758 };
759
760 self.finalize_checkpoint(pending_checkpoint, cx);
761 }
762
    /// Compares `pending_checkpoint` against the repository's current state
    /// and records it only when the two differ (i.e. generation changed the
    /// project). On comparison failure the checkpoint is kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Treat a failed comparison as "changed" so the
                    // checkpoint is preserved.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
797
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
808
809 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
810 let Some(message_ix) = self
811 .messages
812 .iter()
813 .rposition(|message| message.id == message_id)
814 else {
815 return;
816 };
817 for deleted_message in self.messages.drain(message_ix..) {
818 self.checkpoints_by_message.remove(&deleted_message.id);
819 }
820 cx.notify();
821 }
822
823 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
824 self.messages
825 .iter()
826 .find(|message| message.id == id)
827 .into_iter()
828 .flat_map(|message| message.loaded_context.contexts.iter())
829 }
830
831 pub fn is_turn_end(&self, ix: usize) -> bool {
832 if self.messages.is_empty() {
833 return false;
834 }
835
836 if !self.is_generating() && ix == self.messages.len() - 1 {
837 return true;
838 }
839
840 let Some(message) = self.messages.get(ix) else {
841 return false;
842 };
843
844 if message.role != Role::Assistant {
845 return false;
846 }
847
848 self.messages
849 .get(ix + 1)
850 .and_then(|message| {
851 self.message(message.id)
852 .map(|next_message| next_message.role == Role::User)
853 })
854 .unwrap_or(false)
855 }
856
    /// Usage reported by the provider for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
874
    /// Tool uses issued by the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of the tool use with `id`, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card associated with a tool use's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
906
907 /// Return tools that are both enabled and supported by the model
908 pub fn available_tools(
909 &self,
910 cx: &App,
911 model: Arc<dyn LanguageModel>,
912 ) -> Vec<LanguageModelRequestTool> {
913 if model.supports_tools() {
914 self.tools()
915 .read(cx)
916 .enabled_tools(cx)
917 .into_iter()
918 .filter_map(|tool| {
919 // Skip tools that cannot be supported
920 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
921 Some(LanguageModelRequestTool {
922 name: tool.name(),
923 description: tool.description(),
924 input_schema,
925 })
926 })
927 .collect()
928 } else {
929 Vec::default()
930 }
931 }
932
    /// Appends a user message (with its loaded context and creases) to the
    /// thread, remembering `git_checkpoint` as the pending checkpoint to be
    /// finalized once generation settles.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Record a read of every buffer referenced by the context in the
        // action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
968
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next id, bumps `updated_at`, and emits
    /// [`ThreadEvent::MessageAdded`]. Returns the new message's id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1003
    /// Replaces the role and segments of the message with `id`, optionally
    /// updating its context and recording a checkpoint for it.
    ///
    /// Returns `false` (changing nothing) when no message has that id.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1034
1035 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1036 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1037 return false;
1038 };
1039 self.messages.remove(index);
1040 self.touch_updated_at();
1041 cx.emit(ThreadEvent::MessageDeleted(id));
1042 true
1043 }
1044
1045 /// Returns the representation of this [`Thread`] in a textual form.
1046 ///
1047 /// This is the representation we use when attaching a thread as context to another thread.
1048 pub fn text(&self) -> String {
1049 let mut text = String::new();
1050
1051 for message in &self.messages {
1052 text.push_str(match message.role {
1053 language_model::Role::User => "User:",
1054 language_model::Role::Assistant => "Agent:",
1055 language_model::Role::System => "System:",
1056 });
1057 text.push('\n');
1058
1059 for segment in &message.segments {
1060 match segment {
1061 MessageSegment::Text(content) => text.push_str(content),
1062 MessageSegment::Thinking { text: content, .. } => {
1063 text.push_str(&format!("<think>{}</think>", content))
1064 }
1065 MessageSegment::RedactedThinking(_) => {}
1066 }
1067 }
1068 text.push('\n');
1069 }
1070
1071 text
1072 }
1073
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot to resolve, then captures the
    /// thread's messages (with their tool uses/results and creases), token
    /// usage, summary state, and selected model.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted; live
                        // context handles are not serializable.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1156
    /// Returns how many more automatic model turns this thread may take
    /// before it stops sending requests on its own.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the budget of automatic model turns the thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1164
1165 pub fn send_to_model(
1166 &mut self,
1167 model: Arc<dyn LanguageModel>,
1168 window: Option<AnyWindowHandle>,
1169 cx: &mut Context<Self>,
1170 ) {
1171 if self.remaining_turns == 0 {
1172 return;
1173 }
1174
1175 self.remaining_turns -= 1;
1176
1177 let request = self.to_completion_request(model.clone(), cx);
1178
1179 self.stream_completion(request, model, window, cx);
1180 }
1181
1182 pub fn used_tools_since_last_user_message(&self) -> bool {
1183 for message in self.messages.iter().rev() {
1184 if self.tool_use.message_has_tool_results(message.id) {
1185 return true;
1186 } else if message.role == Role::User {
1187 return false;
1188 }
1189 }
1190
1191 false
1192 }
1193
    /// Builds the full `LanguageModelRequest` for this thread: system prompt,
    /// the message history with tool uses and tool results interleaved, a
    /// trailing stale-files notice, the available tools, and the completion
    /// mode appropriate for `model`.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the project context; when it isn't ready yet
        // we surface an error instead of silently sending a promptless request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so mark it cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the latest message eligible for the prompt-cache marker;
        // only that one message is flagged after the loop (see link below).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses go on the assistant message; their results go
            // into a separate user-role message that follows it.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use has no result yet, so this message
                    // isn't stable enough to be a cache point.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Tell the model about files that changed since it last read them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model supports it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1351
1352 fn to_summarize_request(
1353 &self,
1354 model: &Arc<dyn LanguageModel>,
1355 added_user_message: String,
1356 cx: &App,
1357 ) -> LanguageModelRequest {
1358 let mut request = LanguageModelRequest {
1359 thread_id: None,
1360 prompt_id: None,
1361 mode: None,
1362 messages: vec![],
1363 tools: Vec::new(),
1364 tool_choice: None,
1365 stop: Vec::new(),
1366 temperature: AgentSettings::temperature_for_model(model, cx),
1367 };
1368
1369 for message in &self.messages {
1370 let mut request_message = LanguageModelRequestMessage {
1371 role: message.role,
1372 content: Vec::new(),
1373 cache: false,
1374 };
1375
1376 for segment in &message.segments {
1377 match segment {
1378 MessageSegment::Text(text) => request_message
1379 .content
1380 .push(MessageContent::Text(text.clone())),
1381 MessageSegment::Thinking { .. } => {}
1382 MessageSegment::RedactedThinking(_) => {}
1383 }
1384 }
1385
1386 if request_message.content.is_empty() {
1387 continue;
1388 }
1389
1390 request.messages.push(request_message);
1391 }
1392
1393 request.messages.push(LanguageModelRequestMessage {
1394 role: Role::User,
1395 content: vec![MessageContent::Text(added_user_message)],
1396 cache: false,
1397 });
1398
1399 request
1400 }
1401
1402 fn attached_tracked_files_state(
1403 &self,
1404 messages: &mut Vec<LanguageModelRequestMessage>,
1405 cx: &App,
1406 ) {
1407 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1408
1409 let mut stale_message = String::new();
1410
1411 let action_log = self.action_log.read(cx);
1412
1413 for stale_file in action_log.stale_buffers(cx) {
1414 let Some(file) = stale_file.read(cx).file() else {
1415 continue;
1416 };
1417
1418 if stale_message.is_empty() {
1419 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1420 }
1421
1422 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1423 }
1424
1425 let mut content = Vec::with_capacity(2);
1426
1427 if !stale_message.is_empty() {
1428 content.push(stale_message.into());
1429 }
1430
1431 if !content.is_empty() {
1432 let context_message = LanguageModelRequestMessage {
1433 role: Role::User,
1434 content,
1435 cache: false,
1436 };
1437
1438 messages.push(context_message);
1439 }
1440 }
1441
    /// Streams a completion for `request` from `model` on a background task,
    /// applying each streamed event (text, thinking, tool uses, status/usage
    /// updates, stop reasons) to this thread and emitting the corresponding
    /// `ThreadEvent`s. The spawned task is tracked in `self.pending_completions`
    /// so it can be canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Per-request flag; the server may set it again via a status update below.
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Request/response pairs are only captured when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before the request so telemetry below can report
            // this request's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message this response is streaming into,
                // created lazily on the first relevant event.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Invalid tool-call JSON becomes a tool error
                                // instead of aborting the whole stream.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only the delta since the last update is added.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial input is streamed to the UI; complete
                                // input is handled by request_tool_use alone.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Let other tasks (e.g. UI updates) run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to their dedicated UI events;
                            // everything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1821
1822 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1823 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1824 println!("No thread summary model");
1825 return;
1826 };
1827
1828 if !model.provider.is_authenticated(cx) {
1829 return;
1830 }
1831
1832 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1833 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1834 If the conversation is about a specific subject, include it in the title. \
1835 Be descriptive. DO NOT speak in the first person.";
1836
1837 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1838
1839 self.summary = ThreadSummary::Generating;
1840
1841 self.pending_summary = cx.spawn(async move |this, cx| {
1842 let result = async {
1843 let mut messages = model.model.stream_completion(request, &cx).await?;
1844
1845 let mut new_summary = String::new();
1846 while let Some(event) = messages.next().await {
1847 let Ok(event) = event else {
1848 continue;
1849 };
1850 let text = match event {
1851 LanguageModelCompletionEvent::Text(text) => text,
1852 LanguageModelCompletionEvent::StatusUpdate(
1853 CompletionRequestStatus::UsageUpdated { amount, limit },
1854 ) => {
1855 this.update(cx, |thread, _cx| {
1856 thread.last_usage = Some(RequestUsage {
1857 limit,
1858 amount: amount as i32,
1859 });
1860 })?;
1861 continue;
1862 }
1863 _ => continue,
1864 };
1865
1866 let mut lines = text.lines();
1867 new_summary.extend(lines.next());
1868
1869 // Stop if the LLM generated multiple lines.
1870 if lines.next().is_some() {
1871 break;
1872 }
1873 }
1874
1875 anyhow::Ok(new_summary)
1876 }
1877 .await;
1878
1879 this.update(cx, |this, cx| {
1880 match result {
1881 Ok(new_summary) => {
1882 if new_summary.is_empty() {
1883 this.summary = ThreadSummary::Error;
1884 } else {
1885 this.summary = ThreadSummary::Ready(new_summary.into());
1886 }
1887 }
1888 Err(err) => {
1889 this.summary = ThreadSummary::Error;
1890 log::error!("Failed to generate thread summary: {}", err);
1891 }
1892 }
1893 cx.emit(ThreadEvent::SummaryGenerated);
1894 })
1895 .log_err()?;
1896
1897 Some(())
1898 });
1899 }
1900
    /// Starts generating a detailed Markdown summary of the conversation,
    /// unless one is already generated (or being generated) for the latest
    /// message. On success the thread is saved so the summary can be reused.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // On stream failure, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1990
1991 pub async fn wait_for_detailed_summary_or_text(
1992 this: &Entity<Self>,
1993 cx: &mut AsyncApp,
1994 ) -> Option<SharedString> {
1995 let mut detailed_summary_rx = this
1996 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1997 .ok()?;
1998 loop {
1999 match detailed_summary_rx.recv().await? {
2000 DetailedSummaryState::Generating { .. } => {}
2001 DetailedSummaryState::NotGenerated => {
2002 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2003 }
2004 DetailedSummaryState::Generated { text, .. } => return Some(text),
2005 }
2006 }
2007 }
2008
2009 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2010 self.detailed_summary_rx
2011 .borrow()
2012 .text()
2013 .unwrap_or_else(|| self.text().into())
2014 }
2015
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2022
    /// Dispatches every idle pending tool use: tools that require confirmation
    /// are queued for user approval, known tools run immediately, and unknown
    /// tool names are answered with an error. Returns the pending tool uses
    /// that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // The request is captured once and shared so each tool can inspect the
        // conversation state it was invoked with.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Hold the tool for confirmation unless settings allow all
                // tool actions unconditionally.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AgentSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool that doesn't exist or is disabled.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2076
2077 pub fn handle_hallucinated_tool_use(
2078 &mut self,
2079 tool_use_id: LanguageModelToolUseId,
2080 hallucinated_tool_name: Arc<str>,
2081 window: Option<AnyWindowHandle>,
2082 cx: &mut Context<Thread>,
2083 ) {
2084 let available_tools = self.tools.read(cx).enabled_tools(cx);
2085
2086 let tool_list = available_tools
2087 .iter()
2088 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2089 .collect::<Vec<_>>()
2090 .join("\n");
2091
2092 let error_message = format!(
2093 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2094 hallucinated_tool_name, tool_list
2095 );
2096
2097 let pending_tool_use = self.tool_use.insert_tool_output(
2098 tool_use_id.clone(),
2099 hallucinated_tool_name,
2100 Err(anyhow!("Missing tool call: {error_message}")),
2101 self.configured_model.as_ref(),
2102 );
2103
2104 cx.emit(ThreadEvent::MissingToolUse {
2105 tool_use_id: tool_use_id.clone(),
2106 ui_text: error_message.into(),
2107 });
2108
2109 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2110 }
2111
2112 pub fn receive_invalid_tool_json(
2113 &mut self,
2114 tool_use_id: LanguageModelToolUseId,
2115 tool_name: Arc<str>,
2116 invalid_json: Arc<str>,
2117 error: String,
2118 window: Option<AnyWindowHandle>,
2119 cx: &mut Context<Thread>,
2120 ) {
2121 log::error!("The model returned invalid input JSON: {invalid_json}");
2122
2123 let pending_tool_use = self.tool_use.insert_tool_output(
2124 tool_use_id.clone(),
2125 tool_name,
2126 Err(anyhow!("Error parsing input JSON: {error}")),
2127 self.configured_model.as_ref(),
2128 );
2129 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2130 pending_tool_use.ui_text.clone()
2131 } else {
2132 log::error!(
2133 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2134 );
2135 format!("Unknown tool {}", tool_use_id).into()
2136 };
2137
2138 cx.emit(ThreadEvent::InvalidToolInput {
2139 tool_use_id: tool_use_id.clone(),
2140 ui_text,
2141 invalid_input_json: invalid_json,
2142 });
2143
2144 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2145 }
2146
2147 pub fn run_tool(
2148 &mut self,
2149 tool_use_id: LanguageModelToolUseId,
2150 ui_text: impl Into<SharedString>,
2151 input: serde_json::Value,
2152 request: Arc<LanguageModelRequest>,
2153 tool: Arc<dyn Tool>,
2154 model: Arc<dyn LanguageModel>,
2155 window: Option<AnyWindowHandle>,
2156 cx: &mut Context<Thread>,
2157 ) {
2158 let task =
2159 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2160 self.tool_use
2161 .run_pending_tool(tool_use_id, ui_text.into(), task);
2162 }
2163
2164 fn spawn_tool_use(
2165 &mut self,
2166 tool_use_id: LanguageModelToolUseId,
2167 request: Arc<LanguageModelRequest>,
2168 input: serde_json::Value,
2169 tool: Arc<dyn Tool>,
2170 model: Arc<dyn LanguageModel>,
2171 window: Option<AnyWindowHandle>,
2172 cx: &mut Context<Thread>,
2173 ) -> Task<()> {
2174 let tool_name: Arc<str> = tool.name().into();
2175
2176 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2177 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2178 } else {
2179 tool.run(
2180 input,
2181 request,
2182 self.project.clone(),
2183 self.action_log.clone(),
2184 model,
2185 window,
2186 cx,
2187 )
2188 };
2189
2190 // Store the card separately if it exists
2191 if let Some(card) = tool_result.card.clone() {
2192 self.tool_use
2193 .insert_tool_result_card(tool_use_id.clone(), card);
2194 }
2195
2196 cx.spawn({
2197 async move |thread: WeakEntity<Thread>, cx| {
2198 let output = tool_result.output.await;
2199
2200 thread
2201 .update(cx, |thread, cx| {
2202 let pending_tool_use = thread.tool_use.insert_tool_output(
2203 tool_use_id.clone(),
2204 tool_name,
2205 output,
2206 thread.configured_model.as_ref(),
2207 );
2208 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2209 })
2210 .ok();
2211 }
2212 })
2213 }
2214
2215 fn tool_finished(
2216 &mut self,
2217 tool_use_id: LanguageModelToolUseId,
2218 pending_tool_use: Option<PendingToolUse>,
2219 canceled: bool,
2220 window: Option<AnyWindowHandle>,
2221 cx: &mut Context<Self>,
2222 ) {
2223 if self.all_tools_finished() {
2224 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2225 if !canceled {
2226 self.send_to_model(model.clone(), window, cx);
2227 }
2228 self.auto_capture_telemetry(cx);
2229 }
2230 }
2231
2232 cx.emit(ThreadEvent::ToolFinished {
2233 tool_use_id,
2234 pending_tool_use,
2235 });
2236 }
2237
2238 /// Cancels the last pending completion, if there are any pending.
2239 ///
2240 /// Returns whether a completion was canceled.
2241 pub fn cancel_last_completion(
2242 &mut self,
2243 window: Option<AnyWindowHandle>,
2244 cx: &mut Context<Self>,
2245 ) -> bool {
2246 let mut canceled = self.pending_completions.pop().is_some();
2247
2248 for pending_tool_use in self.tool_use.cancel_pending() {
2249 canceled = true;
2250 self.tool_finished(
2251 pending_tool_use.id.clone(),
2252 Some(pending_tool_use),
2253 true,
2254 window,
2255 cx,
2256 );
2257 }
2258
2259 if canceled {
2260 cx.emit(ThreadEvent::CompletionCanceled);
2261
2262 // When canceled, we always want to insert the checkpoint.
2263 // (We skip over finalize_pending_checkpoint, because it
2264 // would conclude we didn't have anything to insert here.)
2265 if let Some(checkpoint) = self.pending_checkpoint.take() {
2266 self.insert_checkpoint(checkpoint, cx);
2267 }
2268 } else {
2269 self.finalize_pending_checkpoint(cx);
2270 }
2271
2272 canceled
2273 }
2274
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. The thread itself keeps no
    /// editing state; this only emits [`ThreadEvent::CancelEditing`].
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2282
    /// Returns the thread-level feedback rating, if any was recorded
    /// (set by [`Self::report_feedback`] when no assistant message exists).
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2286
    /// Returns the feedback rating recorded for a single message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2290
2291 pub fn report_message_feedback(
2292 &mut self,
2293 message_id: MessageId,
2294 feedback: ThreadFeedback,
2295 cx: &mut Context<Self>,
2296 ) -> Task<Result<()>> {
2297 if self.message_feedback.get(&message_id) == Some(&feedback) {
2298 return Task::ready(Ok(()));
2299 }
2300
2301 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2302 let serialized_thread = self.serialize(cx);
2303 let thread_id = self.id().clone();
2304 let client = self.project.read(cx).client();
2305
2306 let enabled_tool_names: Vec<String> = self
2307 .tools()
2308 .read(cx)
2309 .enabled_tools(cx)
2310 .iter()
2311 .map(|tool| tool.name())
2312 .collect();
2313
2314 self.message_feedback.insert(message_id, feedback);
2315
2316 cx.notify();
2317
2318 let message_content = self
2319 .message(message_id)
2320 .map(|msg| msg.to_string())
2321 .unwrap_or_default();
2322
2323 cx.background_spawn(async move {
2324 let final_project_snapshot = final_project_snapshot.await;
2325 let serialized_thread = serialized_thread.await?;
2326 let thread_data =
2327 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2328
2329 let rating = match feedback {
2330 ThreadFeedback::Positive => "positive",
2331 ThreadFeedback::Negative => "negative",
2332 };
2333 telemetry::event!(
2334 "Assistant Thread Rated",
2335 rating,
2336 thread_id,
2337 enabled_tool_names,
2338 message_id = message_id.0,
2339 message_content,
2340 thread_data,
2341 final_project_snapshot
2342 );
2343 client.telemetry().flush_events().await;
2344
2345 Ok(())
2346 })
2347 }
2348
2349 pub fn report_feedback(
2350 &mut self,
2351 feedback: ThreadFeedback,
2352 cx: &mut Context<Self>,
2353 ) -> Task<Result<()>> {
2354 let last_assistant_message_id = self
2355 .messages
2356 .iter()
2357 .rev()
2358 .find(|msg| msg.role == Role::Assistant)
2359 .map(|msg| msg.id);
2360
2361 if let Some(message_id) = last_assistant_message_id {
2362 self.report_message_feedback(message_id, feedback, cx)
2363 } else {
2364 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2365 let serialized_thread = self.serialize(cx);
2366 let thread_id = self.id().clone();
2367 let client = self.project.read(cx).client();
2368 self.feedback = Some(feedback);
2369 cx.notify();
2370
2371 cx.background_spawn(async move {
2372 let final_project_snapshot = final_project_snapshot.await;
2373 let serialized_thread = serialized_thread.await?;
2374 let thread_data = serde_json::to_value(serialized_thread)
2375 .unwrap_or_else(|_| serde_json::Value::Null);
2376
2377 let rating = match feedback {
2378 ThreadFeedback::Positive => "positive",
2379 ThreadFeedback::Negative => "negative",
2380 };
2381 telemetry::event!(
2382 "Assistant Thread Rated",
2383 rating,
2384 thread_id,
2385 thread_data,
2386 final_project_snapshot
2387 );
2388 client.telemetry().flush_events().await;
2389
2390 Ok(())
2391 })
2392 }
2393 }
2394
2395 /// Create a snapshot of the current project state including git information and unsaved buffers.
2396 fn project_snapshot(
2397 project: Entity<Project>,
2398 cx: &mut Context<Self>,
2399 ) -> Task<Arc<ProjectSnapshot>> {
2400 let git_store = project.read(cx).git_store().clone();
2401 let worktree_snapshots: Vec<_> = project
2402 .read(cx)
2403 .visible_worktrees(cx)
2404 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2405 .collect();
2406
2407 cx.spawn(async move |_, cx| {
2408 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2409
2410 let mut unsaved_buffers = Vec::new();
2411 cx.update(|app_cx| {
2412 let buffer_store = project.read(app_cx).buffer_store();
2413 for buffer_handle in buffer_store.read(app_cx).buffers() {
2414 let buffer = buffer_handle.read(app_cx);
2415 if buffer.is_dirty() {
2416 if let Some(file) = buffer.file() {
2417 let path = file.path().to_string_lossy().to_string();
2418 unsaved_buffers.push(path);
2419 }
2420 }
2421 }
2422 })
2423 .ok();
2424
2425 Arc::new(ProjectSnapshot {
2426 worktree_snapshots,
2427 unsaved_buffer_paths: unsaved_buffers,
2428 timestamp: Utc::now(),
2429 })
2430 })
2431 }
2432
    /// Asynchronously builds a [`WorktreeSnapshot`] for one worktree:
    /// resolves its absolute path, then (best-effort) gathers git metadata
    /// from the repository that contains that path.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot rather
            // than failing the whole project snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working copy contains this worktree's
            // absolute path, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend; for
                            // any other state we can report just the branch.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // `git_state` here is Option<Result<Task<GitState>>>: flatten
            // every failure layer to None so git problems never fail the
            // snapshot itself.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2510
2511 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2512 let mut markdown = Vec::new();
2513
2514 let summary = self.summary().or_default();
2515 writeln!(markdown, "# {summary}\n")?;
2516
2517 for message in self.messages() {
2518 writeln!(
2519 markdown,
2520 "## {role}\n",
2521 role = match message.role {
2522 Role::User => "User",
2523 Role::Assistant => "Agent",
2524 Role::System => "System",
2525 }
2526 )?;
2527
2528 if !message.loaded_context.text.is_empty() {
2529 writeln!(markdown, "{}", message.loaded_context.text)?;
2530 }
2531
2532 if !message.loaded_context.images.is_empty() {
2533 writeln!(
2534 markdown,
2535 "\n{} images attached as context.\n",
2536 message.loaded_context.images.len()
2537 )?;
2538 }
2539
2540 for segment in &message.segments {
2541 match segment {
2542 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2543 MessageSegment::Thinking { text, .. } => {
2544 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2545 }
2546 MessageSegment::RedactedThinking(_) => {}
2547 }
2548 }
2549
2550 for tool_use in self.tool_uses_for_message(message.id, cx) {
2551 writeln!(
2552 markdown,
2553 "**Use Tool: {} ({})**",
2554 tool_use.name, tool_use.id
2555 )?;
2556 writeln!(markdown, "```json")?;
2557 writeln!(
2558 markdown,
2559 "{}",
2560 serde_json::to_string_pretty(&tool_use.input)?
2561 )?;
2562 writeln!(markdown, "```")?;
2563 }
2564
2565 for tool_result in self.tool_results_for_message(message.id) {
2566 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2567 if tool_result.is_error {
2568 write!(markdown, " (Error)")?;
2569 }
2570
2571 writeln!(markdown, "**\n")?;
2572 match &tool_result.content {
2573 LanguageModelToolResultContent::Text(text)
2574 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2575 text,
2576 ..
2577 }) => {
2578 writeln!(markdown, "{text}")?;
2579 }
2580 LanguageModelToolResultContent::Image(image) => {
2581 writeln!(markdown, "", image.source)?;
2582 }
2583 }
2584
2585 if let Some(output) = tool_result.output.as_ref() {
2586 writeln!(
2587 markdown,
2588 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2589 serde_json::to_string_pretty(output)?
2590 )?;
2591 }
2592 }
2593 }
2594
2595 Ok(String::from_utf8_lossy(&markdown).to_string())
2596 }
2597
    /// Accepts the agent's edits within `buffer_range` by delegating to
    /// [`ActionLog::keep_edits_in_range`] on this thread's action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2608
    /// Accepts all of the agent's outstanding edits via the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2613
    /// Rejects the agent's edits within the given ranges, delegating to
    /// [`ActionLog::reject_edits_in_ranges`]; the returned task resolves
    /// when the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2624
    /// The action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2628
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2632
2633 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2634 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2635 return;
2636 }
2637
2638 let now = Instant::now();
2639 if let Some(last) = self.last_auto_capture_at {
2640 if now.duration_since(last).as_secs() < 10 {
2641 return;
2642 }
2643 }
2644
2645 self.last_auto_capture_at = Some(now);
2646
2647 let thread_id = self.id().clone();
2648 let github_login = self
2649 .project
2650 .read(cx)
2651 .user_store()
2652 .read(cx)
2653 .current_user()
2654 .map(|user| user.github_login.clone());
2655 let client = self.project.read(cx).client();
2656 let serialize_task = self.serialize(cx);
2657
2658 cx.background_executor()
2659 .spawn(async move {
2660 if let Ok(serialized_thread) = serialize_task.await {
2661 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2662 telemetry::event!(
2663 "Agent Thread Auto-Captured",
2664 thread_id = thread_id.to_string(),
2665 thread_data = thread_data,
2666 auto_capture_reason = "tracked_user",
2667 github_login = github_login
2668 );
2669
2670 client.telemetry().flush_events().await;
2671 }
2672 }
2673 })
2674 .detach();
2675 }
2676
    /// Total token usage accumulated across every request in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2680
2681 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2682 let Some(model) = self.configured_model.as_ref() else {
2683 return TotalTokenUsage::default();
2684 };
2685
2686 let max = model.model.max_token_count();
2687
2688 let index = self
2689 .messages
2690 .iter()
2691 .position(|msg| msg.id == message_id)
2692 .unwrap_or(0);
2693
2694 if index == 0 {
2695 return TotalTokenUsage { total: 0, max };
2696 }
2697
2698 let token_usage = &self
2699 .request_token_usage
2700 .get(index - 1)
2701 .cloned()
2702 .unwrap_or_default();
2703
2704 TotalTokenUsage {
2705 total: token_usage.total_tokens() as usize,
2706 max,
2707 }
2708 }
2709
2710 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2711 let model = self.configured_model.as_ref()?;
2712
2713 let max = model.model.max_token_count();
2714
2715 if let Some(exceeded_error) = &self.exceeded_window_error {
2716 if model.model.id() == exceeded_error.model_id {
2717 return Some(TotalTokenUsage {
2718 total: exceeded_error.token_count,
2719 max,
2720 });
2721 }
2722 }
2723
2724 let total = self
2725 .token_usage_at_last_message()
2726 .unwrap_or_default()
2727 .total_tokens() as usize;
2728
2729 Some(TotalTokenUsage { total, max })
2730 }
2731
2732 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2733 self.request_token_usage
2734 .get(self.messages.len().saturating_sub(1))
2735 .or_else(|| self.request_token_usage.last())
2736 .cloned()
2737 }
2738
2739 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2740 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2741 self.request_token_usage
2742 .resize(self.messages.len(), placeholder);
2743
2744 if let Some(last) = self.request_token_usage.last_mut() {
2745 *last = token_usage;
2746 }
2747 }
2748
2749 pub fn deny_tool_use(
2750 &mut self,
2751 tool_use_id: LanguageModelToolUseId,
2752 tool_name: Arc<str>,
2753 window: Option<AnyWindowHandle>,
2754 cx: &mut Context<Self>,
2755 ) {
2756 let err = Err(anyhow::anyhow!(
2757 "Permission to run tool action denied by user"
2758 ));
2759
2760 self.tool_use.insert_tool_output(
2761 tool_use_id.clone(),
2762 tool_name,
2763 err,
2764 self.configured_model.as_ref(),
2765 );
2766 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2767 }
2768}
2769
/// User-visible errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The completion request failed because payment is required.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has no model requests remaining.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2782
/// Events emitted by a [`Thread`] while completions stream and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use streamed in from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that could not be found.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse/validate.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        // The raw JSON that failed, kept for display/diagnostics.
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool call produced its output (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Listeners should cancel any in-progress editing operations.
    CancelEditing,
    /// A pending completion and/or pending tool uses were canceled.
    CompletionCanceled,
}
2825
// Lets `Thread` emit `ThreadEvent`s through `cx.emit(...)`.
impl EventEmitter<ThreadEvent> for Thread {}
2827
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    queue_state: QueueState,
    // Held to keep the streaming task alive for the completion's lifetime.
    _task: Task<()>,
}
2833
2834#[cfg(test)]
2835mod tests {
2836 use super::*;
2837 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2838 use agent_settings::{AgentSettings, LanguageModelParameters};
2839 use assistant_tool::ToolRegistry;
2840 use editor::EditorSettings;
2841 use gpui::TestAppContext;
2842 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2843 use project::{FakeFs, Project};
2844 use prompt_store::PromptBuilder;
2845 use serde_json::json;
2846 use settings::{Settings, SettingsStore};
2847 use std::sync::Arc;
2848 use theme::ThemeSettings;
2849 use util::path;
2850 use workspace::Workspace;
2851
    /// Verifies that a user message with an attached file context stores the
    /// rendered `<context>` block on the message and prepends it to the
    /// message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2928
    /// Verifies that each successive user message only attaches context items
    /// not already included by an earlier message, and that
    /// `new_context_for_thread` honors an "editing from message X" cutoff.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Editing from message 2: contexts attached strictly before it stay
        // excluded, everything else (2, 3, 4) counts as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3084
    /// Verifies that user messages inserted with an empty context produce no
    /// `<context>` text, and that plain message contents round-trip into the
    /// completion request unchanged.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3162
    /// Verifies that modifying a context buffer after it was attached causes
    /// the completion request to end with a "files changed since last read"
    /// notification message, and that no such warning appears beforehand.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3250
    /// Verifies how `model_parameters` in the agent settings select a request
    /// temperature: matching on provider+model, model alone, or provider
    /// alone applies the override; a provider mismatch applies nothing.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3344
3345 #[gpui::test]
3346 async fn test_thread_summary(cx: &mut TestAppContext) {
3347 init_test_settings(cx);
3348
3349 let project = create_test_project(cx, json!({})).await;
3350
3351 let (_, _thread_store, thread, _context_store, model) =
3352 setup_test_environment(cx, project.clone()).await;
3353
3354 // Initial state should be pending
3355 thread.read_with(cx, |thread, _| {
3356 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3357 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3358 });
3359
3360 // Manually setting the summary should not be allowed in this state
3361 thread.update(cx, |thread, cx| {
3362 thread.set_summary("This should not work", cx);
3363 });
3364
3365 thread.read_with(cx, |thread, _| {
3366 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3367 });
3368
3369 // Send a message
3370 thread.update(cx, |thread, cx| {
3371 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3372 thread.send_to_model(model.clone(), None, cx);
3373 });
3374
3375 let fake_model = model.as_fake();
3376 simulate_successful_response(&fake_model, cx);
3377
3378 // Should start generating summary when there are >= 2 messages
3379 thread.read_with(cx, |thread, _| {
3380 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3381 });
3382
3383 // Should not be able to set the summary while generating
3384 thread.update(cx, |thread, cx| {
3385 thread.set_summary("This should not work either", cx);
3386 });
3387
3388 thread.read_with(cx, |thread, _| {
3389 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3390 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3391 });
3392
3393 cx.run_until_parked();
3394 fake_model.stream_last_completion_response("Brief".into());
3395 fake_model.stream_last_completion_response(" Introduction".into());
3396 fake_model.end_last_completion_stream();
3397 cx.run_until_parked();
3398
3399 // Summary should be set
3400 thread.read_with(cx, |thread, _| {
3401 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3402 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3403 });
3404
3405 // Now we should be able to set a summary
3406 thread.update(cx, |thread, cx| {
3407 thread.set_summary("Brief Intro", cx);
3408 });
3409
3410 thread.read_with(cx, |thread, _| {
3411 assert_eq!(thread.summary().or_default(), "Brief Intro");
3412 });
3413
3414 // Test setting an empty summary (should default to DEFAULT)
3415 thread.update(cx, |thread, cx| {
3416 thread.set_summary("", cx);
3417 });
3418
3419 thread.read_with(cx, |thread, _| {
3420 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3421 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3422 });
3423 }
3424
3425 #[gpui::test]
3426 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3427 init_test_settings(cx);
3428
3429 let project = create_test_project(cx, json!({})).await;
3430
3431 let (_, _thread_store, thread, _context_store, model) =
3432 setup_test_environment(cx, project.clone()).await;
3433
3434 test_summarize_error(&model, &thread, cx);
3435
3436 // Now we should be able to set a summary
3437 thread.update(cx, |thread, cx| {
3438 thread.set_summary("Brief Intro", cx);
3439 });
3440
3441 thread.read_with(cx, |thread, _| {
3442 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3443 assert_eq!(thread.summary().or_default(), "Brief Intro");
3444 });
3445 }
3446
3447 #[gpui::test]
3448 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3449 init_test_settings(cx);
3450
3451 let project = create_test_project(cx, json!({})).await;
3452
3453 let (_, _thread_store, thread, _context_store, model) =
3454 setup_test_environment(cx, project.clone()).await;
3455
3456 test_summarize_error(&model, &thread, cx);
3457
3458 // Sending another message should not trigger another summarize request
3459 thread.update(cx, |thread, cx| {
3460 thread.insert_user_message(
3461 "How are you?",
3462 ContextLoadResult::default(),
3463 None,
3464 vec![],
3465 cx,
3466 );
3467 thread.send_to_model(model.clone(), None, cx);
3468 });
3469
3470 let fake_model = model.as_fake();
3471 simulate_successful_response(&fake_model, cx);
3472
3473 thread.read_with(cx, |thread, _| {
3474 // State is still Error, not Generating
3475 assert!(matches!(thread.summary(), ThreadSummary::Error));
3476 });
3477
3478 // But the summarize request can be invoked manually
3479 thread.update(cx, |thread, cx| {
3480 thread.summarize(cx);
3481 });
3482
3483 thread.read_with(cx, |thread, _| {
3484 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3485 });
3486
3487 cx.run_until_parked();
3488 fake_model.stream_last_completion_response("A successful summary".into());
3489 fake_model.end_last_completion_stream();
3490 cx.run_until_parked();
3491
3492 thread.read_with(cx, |thread, _| {
3493 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3494 assert_eq!(thread.summary().or_default(), "A successful summary");
3495 });
3496 }
3497
3498 fn test_summarize_error(
3499 model: &Arc<dyn LanguageModel>,
3500 thread: &Entity<Thread>,
3501 cx: &mut TestAppContext,
3502 ) {
3503 thread.update(cx, |thread, cx| {
3504 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3505 thread.send_to_model(model.clone(), None, cx);
3506 });
3507
3508 let fake_model = model.as_fake();
3509 simulate_successful_response(&fake_model, cx);
3510
3511 thread.read_with(cx, |thread, _| {
3512 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3513 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3514 });
3515
3516 // Simulate summary request ending
3517 cx.run_until_parked();
3518 fake_model.end_last_completion_stream();
3519 cx.run_until_parked();
3520
3521 // State is set to Error and default message
3522 thread.read_with(cx, |thread, _| {
3523 assert!(matches!(thread.summary(), ThreadSummary::Error));
3524 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3525 });
3526 }
3527
    /// Streams a canned assistant reply ("Assistant response") from
    /// `fake_model` and settles the executor before and after, so the thread
    /// fully processes the completion.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3534
    /// Registers the global settings store and every subsystem these thread
    /// tests depend on. Must run before any other setup that reads settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // NOTE(review): registration order is kept as-is; some of these
            // init calls may read settings registered by earlier ones.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3551
3552 // Helper to create a test project with test files
3553 async fn create_test_project(
3554 cx: &mut TestAppContext,
3555 files: serde_json::Value,
3556 ) -> Entity<Project> {
3557 let fs = FakeFs::new(cx.executor());
3558 fs.insert_tree(path!("/test"), files).await;
3559 Project::test(fs, [path!("/test").as_ref()], cx).await
3560 }
3561
3562 async fn setup_test_environment(
3563 cx: &mut TestAppContext,
3564 project: Entity<Project>,
3565 ) -> (
3566 Entity<Workspace>,
3567 Entity<ThreadStore>,
3568 Entity<Thread>,
3569 Entity<ContextStore>,
3570 Arc<dyn LanguageModel>,
3571 ) {
3572 let (workspace, cx) =
3573 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3574
3575 let thread_store = cx
3576 .update(|_, cx| {
3577 ThreadStore::load(
3578 project.clone(),
3579 cx.new(|_| ToolWorkingSet::default()),
3580 None,
3581 Arc::new(PromptBuilder::new(None).unwrap()),
3582 cx,
3583 )
3584 })
3585 .await
3586 .unwrap();
3587
3588 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3589 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3590
3591 let provider = Arc::new(FakeLanguageModelProvider);
3592 let model = provider.test_model();
3593 let model: Arc<dyn LanguageModel> = Arc::new(model);
3594
3595 cx.update(|_, cx| {
3596 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3597 registry.set_default_model(
3598 Some(ConfiguredModel {
3599 provider: provider.clone(),
3600 model: model.clone(),
3601 }),
3602 cx,
3603 );
3604 registry.set_thread_summary_model(
3605 Some(ConfiguredModel {
3606 provider,
3607 model: model.clone(),
3608 }),
3609 cx,
3610 );
3611 })
3612 });
3613
3614 (workspace, thread_store, thread, context_store, model)
3615 }
3616
3617 async fn add_file_to_context(
3618 project: &Entity<Project>,
3619 context_store: &Entity<ContextStore>,
3620 path: &str,
3621 cx: &mut TestAppContext,
3622 ) -> Result<Entity<language::Buffer>> {
3623 let buffer_path = project
3624 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3625 .unwrap();
3626
3627 let buffer = project
3628 .update(cx, |project, cx| {
3629 project.open_buffer(buffer_path.clone(), cx)
3630 })
3631 .await
3632 .unwrap();
3633
3634 context_store.update(cx, |context_store, cx| {
3635 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3636 });
3637
3638 Ok(buffer)
3639 }
3640}