1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Unique identifier for a [`Thread`].
///
/// Normally a freshly generated v4 UUID string (see [`ThreadId::new`]), but
/// any string can be wrapped via `From<&str>` (e.g. when deserializing).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a freshly generated v4 UUID string (see [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Identifier of a [`Message`] within a single [`Thread`].
///
/// Ids come from a monotonically increasing counter (see [`MessageId::post_inc`]),
/// so later messages always have larger ids and the message list stays sorted.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text covered by the crease.
    pub range: Range<usize>,
    /// Display metadata (icon and label) used to re-render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content runs: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render context pills when editing this message.
    pub creases: Vec<MessageCrease>,
    /// Whether the message is excluded from display
    /// (see [`Thread::insert_invisible_continue_message`]).
    pub is_hidden: bool,
}
120
121impl Message {
122 /// Returns whether the message contains any meaningful text that should be displayed
123 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
124 pub fn should_display_content(&self) -> bool {
125 self.segments.iter().all(|segment| segment.should_display())
126 }
127
128 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
129 if let Some(MessageSegment::Thinking {
130 text: segment,
131 signature: current_signature,
132 }) = self.segments.last_mut()
133 {
134 if let Some(signature) = signature {
135 *current_signature = Some(signature);
136 }
137 segment.push_str(text);
138 } else {
139 self.segments.push(MessageSegment::Thinking {
140 text: text.to_string(),
141 signature,
142 });
143 }
144 }
145
146 pub fn push_text(&mut self, text: &str) {
147 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Text(text.to_string()));
151 }
152 }
153
154 pub fn to_string(&self) -> String {
155 let mut result = String::new();
156
157 if !self.loaded_context.text.is_empty() {
158 result.push_str(&self.loaded_context.text);
159 }
160
161 for segment in &self.segments {
162 match segment {
163 MessageSegment::Text(text) => result.push_str(text),
164 MessageSegment::Thinking { text, .. } => {
165 result.push_str("<think>\n");
166 result.push_str(text);
167 result.push_str("\n</think>");
168 }
169 MessageSegment::RedactedThinking(_) => {}
170 }
171 }
172
173 result
174 }
175}
176
/// A single run of content within a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary text content.
    Text(String),
    /// Extended "thinking" output, optionally carrying a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content the provider redacted; opaque bytes, never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment has any displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayed.
    ///
    /// Fix: the predicate was inverted (`text.is_empty()`), which claimed
    /// that only *empty* segments should display — contradicting the method
    /// name, the `RedactedThinking => false` arm, and the caller's contract
    /// in [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
196
/// Point-in-time capture of the project's worktrees, persisted with a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff at snapshot time, if one was captured.
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message that triggered it, so the
/// project can later be restored to the state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback given on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
229
/// Progress of the most recent checkpoint restore.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` carries the failure message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
239
240impl LastRestoreCheckpoint {
241 pub fn message_id(&self) -> MessageId {
242 match self {
243 LastRestoreCheckpoint::Pending { message_id } => *message_id,
244 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
245 }
246 }
247}
248
/// State of the on-demand detailed summary for a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary is being generated for the thread as of `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary is available for the thread as of `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
261
262impl DetailedSummaryState {
263 fn text(&self) -> Option<SharedString> {
264 if let Self::Generated { text, .. } = self {
265 Some(text.clone())
266 } else {
267 None
268 }
269 }
270}
271
/// Running token total for a thread relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size of the selected model; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context window.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A missing or
    /// unparsable value falls back to the default of 0.8 — previously a
    /// malformed value would panic via `.unwrap()` on the parse result.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|threshold| threshold.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
316
/// Where a pending completion stands before the model starts responding.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is queued at `position`.
    Queued { position: usize },
    /// The model has started responding.
    Started,
}
323
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Summary state plus the background tasks and watch channel that feed it.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // The conversation itself; messages stay sorted by ascending id.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, plus restore bookkeeping.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    // Tool working set and the state of in-flight/completed tool uses.
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Token accounting for requests issued from this thread.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    // User feedback on the whole thread and on individual messages.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional observer of every request and its streamed response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
364
/// Lifecycle of a thread's one-line summary.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary has been generated yet.
    Pending,
    /// A summary is currently being generated.
    Generating,
    /// A summary is available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
372
373impl ThreadSummary {
374 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
375
376 pub fn or_default(&self) -> SharedString {
377 self.unwrap_or(Self::DEFAULT)
378 }
379
380 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
381 self.ready().unwrap_or_else(|| message.into())
382 }
383
384 pub fn ready(&self) -> Option<SharedString> {
385 match self {
386 ThreadSummary::Ready(summary) => Some(summary.clone()),
387 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
388 }
389 }
390}
391
/// Recorded when a completion request exceeds the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
399
400impl Thread {
    /// Creates a fresh, empty thread backed by the given project and tool set.
    ///
    /// The configured model defaults to the registry's default model, and a
    /// snapshot of the project is captured in the background for later
    /// serialization.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off the snapshot now; `serialize` awaits the shared
                // task later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
454
    /// Reconstructs a thread from its serialized form.
    ///
    /// Live context handles and git checkpoints are not persisted, so
    /// restored messages carry `context: None` and the checkpoint map starts
    /// empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // registry's default if it can no longer be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Live context handles are not serialized.
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
575
    /// Registers a callback that observes each completion request along with
    /// the events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's unique id.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Assigns a fresh [`PromptId`] for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project context used when building requests.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default model if none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for completions and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
623
    /// The current summary state of the thread.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the thread summary, emitting [`ThreadEvent::SummaryChanged`] only
    /// when it actually changed.
    ///
    /// No-op while a summary is pending or being generated; an empty summary
    /// is replaced with [`ThreadSummary::DEFAULT`].
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // After an error, compare against the default so any real
            // summary counts as a change.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for this thread.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
654
    /// Looks up a message by id.
    ///
    /// Uses binary search: ids come from an increasing counter, so
    /// `self.messages` is always sorted by id.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    ///
    /// Returns `None` when no chunk has been received yet; otherwise reports
    /// whether more than the stale threshold (in milliseconds) has elapsed
    /// since the last chunk. Only meaningful while `is_generating()` is true.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
680
    /// Records that a streaming chunk just arrived (feeds `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for a message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
716
    /// Restores the project to a previously recorded git checkpoint.
    ///
    /// Marks the restore as pending, asks the git store to restore the
    /// checkpoint, then on success truncates the thread back to the
    /// checkpoint's message; on failure the error is kept so it can be
    /// surfaced. Emits [`ThreadEvent::CheckpointChanged`] in both phases.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around for display; the thread is left
                    // untouched.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
752
753 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
754 let pending_checkpoint = if self.is_generating() {
755 return;
756 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
757 checkpoint
758 } else {
759 return;
760 };
761
762 self.finalize_checkpoint(pending_checkpoint, cx);
763 }
764
    /// Compares the pending checkpoint against the project's current state and
    /// keeps it only if something actually changed.
    ///
    /// If the comparison itself fails, the checkpoint is kept conservatively,
    /// since we cannot prove it is redundant.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Comparison error: assume the checkpoints differ.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not capture the final state: keep the checkpoint.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
799
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if one is pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
810
811 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
812 let Some(message_ix) = self
813 .messages
814 .iter()
815 .rposition(|message| message.id == message_id)
816 else {
817 return;
818 };
819 for deleted_message in self.messages.drain(message_ix..) {
820 self.checkpoints_by_message.remove(&deleted_message.id);
821 }
822 cx.notify();
823 }
824
    /// Iterates over the contexts attached to a particular message; empty when
    /// the message does not exist.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
832
833 pub fn is_turn_end(&self, ix: usize) -> bool {
834 if self.messages.is_empty() {
835 return false;
836 }
837
838 if !self.is_generating() && ix == self.messages.len() - 1 {
839 return true;
840 }
841
842 let Some(message) = self.messages.get(ix) else {
843 return false;
844 };
845
846 if message.role != Role::Assistant {
847 return false;
848 }
849
850 self.messages
851 .get(ix + 1)
852 .and_then(|message| {
853 self.message(message.id)
854 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
855 })
856 .unwrap_or(false)
857 }
858
    /// Usage reported by the most recent completion, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the server signalled that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
876
    /// Tool uses associated with the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a tool use; `None` for image results or when the
    /// tool has no result yet.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card produced for a tool use, if one exists.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
905
906 /// Return tools that are both enabled and supported by the model
907 pub fn available_tools(
908 &self,
909 cx: &App,
910 model: Arc<dyn LanguageModel>,
911 ) -> Vec<LanguageModelRequestTool> {
912 if model.supports_tools() {
913 self.tools()
914 .read(cx)
915 .enabled_tools(cx)
916 .into_iter()
917 .filter_map(|tool| {
918 // Skip tools that cannot be supported
919 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
920 Some(LanguageModelRequestTool {
921 name: tool.name(),
922 description: tool.description(),
923 input_schema,
924 })
925 })
926 .collect()
927 } else {
928 Vec::default()
929 }
930 }
931
    /// Appends a user message, recording context buffer reads in the action
    /// log and an optional pending git checkpoint for the message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Let the action log know which buffers were pulled in as context.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // The checkpoint stays pending; it is only kept if the project
        // actually changed (see `finalize_pending_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
968
    /// Appends a hidden user message prompting the model to continue, used to
    /// resume a turn without showing anything extra in the conversation.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        // A synthetic continuation should not carry its own checkpoint.
        self.pending_checkpoint = None;

        id
    }

    /// Appends a visible assistant message with the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
997
998 pub fn insert_message(
999 &mut self,
1000 role: Role,
1001 segments: Vec<MessageSegment>,
1002 loaded_context: LoadedContext,
1003 creases: Vec<MessageCrease>,
1004 is_hidden: bool,
1005 cx: &mut Context<Self>,
1006 ) -> MessageId {
1007 let id = self.next_message_id.post_inc();
1008 self.messages.push(Message {
1009 id,
1010 role,
1011 segments,
1012 loaded_context,
1013 creases,
1014 is_hidden,
1015 });
1016 self.touch_updated_at();
1017 cx.emit(ThreadEvent::MessageAdded(id));
1018 id
1019 }
1020
    /// Replaces a message's role and segments in place, optionally replacing
    /// its loaded context and recording a git checkpoint for it.
    ///
    /// Returns `false` when no message with `id` exists; emits
    /// [`ThreadEvent::MessageEdited`] on success.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1051
1052 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1053 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1054 return false;
1055 };
1056 self.messages.remove(index);
1057 self.touch_updated_at();
1058 cx.emit(ThreadEvent::MessageDeleted(id));
1059 true
1060 }
1061
1062 /// Returns the representation of this [`Thread`] in a textual form.
1063 ///
1064 /// This is the representation we use when attaching a thread as context to another thread.
1065 pub fn text(&self) -> String {
1066 let mut text = String::new();
1067
1068 for message in &self.messages {
1069 text.push_str(match message.role {
1070 language_model::Role::User => "User:",
1071 language_model::Role::Assistant => "Agent:",
1072 language_model::Role::System => "System:",
1073 });
1074 text.push('\n');
1075
1076 for segment in &message.segments {
1077 match segment {
1078 MessageSegment::Text(content) => text.push_str(content),
1079 MessageSegment::Thinking { text: content, .. } => {
1080 text.push_str(&format!("<think>{}</think>", content))
1081 }
1082 MessageSegment::RedactedThinking(_) => {}
1083 }
1084 }
1085 text.push('\n');
1086 }
1087
1088 text
1089 }
1090
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot captured at thread creation to resolve.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results live in `tool_use` keyed by
                        // message id, so they are re-attached per message.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
            })
        })
    }
1175
/// Returns how many agentic turns this thread may still take before
/// `send_to_model` becomes a no-op.
pub fn remaining_turns(&self) -> u32 {
    self.remaining_turns
}
1179
/// Sets the number of agentic turns this thread may still take.
pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
    self.remaining_turns = remaining_turns;
}
1183
1184 pub fn send_to_model(
1185 &mut self,
1186 model: Arc<dyn LanguageModel>,
1187 intent: CompletionIntent,
1188 window: Option<AnyWindowHandle>,
1189 cx: &mut Context<Self>,
1190 ) {
1191 if self.remaining_turns == 0 {
1192 return;
1193 }
1194
1195 self.remaining_turns -= 1;
1196
1197 let request = self.to_completion_request(model.clone(), intent, cx);
1198
1199 self.stream_completion(request, model, window, cx);
1200 }
1201
1202 pub fn used_tools_since_last_user_message(&self) -> bool {
1203 for message in self.messages.iter().rev() {
1204 if self.tool_use.message_has_tool_results(message.id) {
1205 return true;
1206 } else if message.role == Role::User {
1207 return false;
1208 }
1209 }
1210
1211 false
1212 }
1213
/// Builds the `LanguageModelRequest` for this thread: system prompt,
/// conversation history (with tool uses paired to their results), the
/// stale-file notice, the available tools, and the completion mode.
///
/// Emits a `ThreadEvent::ShowError` (without aborting) when the system
/// prompt cannot be generated.
pub fn to_completion_request(
    &self,
    model: Arc<dyn LanguageModel>,
    intent: CompletionIntent,
    cx: &mut Context<Self>,
) -> LanguageModelRequest {
    let mut request = LanguageModelRequest {
        thread_id: Some(self.id.to_string()),
        prompt_id: Some(self.last_prompt_id.to_string()),
        intent: Some(intent),
        mode: None,
        messages: vec![],
        tools: Vec::new(),
        tool_choice: None,
        stop: Vec::new(),
        temperature: AgentSettings::temperature_for_model(&model, cx),
    };

    // Tool names are fed into the system prompt so it can describe them.
    let available_tools = self.available_tools(cx, model.clone());
    let available_tool_names = available_tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();

    let model_context = &ModelContext {
        available_tools: available_tool_names,
    };

    // The system prompt requires project context; if it is missing or the
    // prompt template fails, surface an error but still send the request.
    if let Some(project_context) = self.project_context.borrow().as_ref() {
        match self
            .prompt_builder
            .generate_assistant_system_prompt(project_context, model_context)
        {
            Err(err) => {
                let message = format!("{err:?}").into();
                log::error!("{message}");
                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                    header: "Error generating system prompt".into(),
                    message,
                }));
            }
            Ok(system_prompt) => {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        }
    } else {
        let message = "Context for system prompt unexpectedly not ready.".into();
        log::error!("{message}");
        cx.emit(ThreadEvent::ShowError(ThreadError::Message {
            header: "Error generating system prompt".into(),
            message,
        }));
    }

    // Tracks the last request-message index eligible for prompt caching;
    // a message with still-pending tool uses disqualifies itself.
    let mut message_ix_to_cache = None;
    for message in &self.messages {
        let mut request_message = LanguageModelRequestMessage {
            role: message.role,
            content: Vec::new(),
            cache: false,
        };

        // Attached context (files, symbols, etc.) goes before the segments.
        message
            .loaded_context
            .add_to_request_message(&mut request_message);

        for segment in &message.segments {
            match segment {
                MessageSegment::Text(text) => {
                    // Empty segments are dropped rather than sent.
                    if !text.is_empty() {
                        request_message
                            .content
                            .push(MessageContent::Text(text.into()));
                    }
                }
                MessageSegment::Thinking { text, signature } => {
                    if !text.is_empty() {
                        request_message.content.push(MessageContent::Thinking {
                            text: text.into(),
                            signature: signature.clone(),
                        });
                    }
                }
                MessageSegment::RedactedThinking(data) => {
                    request_message
                        .content
                        .push(MessageContent::RedactedThinking(data.clone()));
                }
            };
        }

        let mut cache_message = true;
        // Tool results must be delivered in a separate user-role message
        // that immediately follows the assistant's tool uses.
        let mut tool_results_message = LanguageModelRequestMessage {
            role: Role::User,
            content: Vec::new(),
            cache: false,
        };
        for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
            if let Some(tool_result) = tool_result {
                request_message
                    .content
                    .push(MessageContent::ToolUse(tool_use.clone()));
                tool_results_message
                    .content
                    .push(MessageContent::ToolResult(LanguageModelToolResult {
                        tool_use_id: tool_use.id.clone(),
                        tool_name: tool_result.tool_name.clone(),
                        is_error: tool_result.is_error,
                        content: if tool_result.content.is_empty() {
                            // Surprisingly, the API fails if we return an empty string here.
                            // It thinks we are sending a tool use without a tool result.
                            "<Tool returned an empty string>".into()
                        } else {
                            tool_result.content.clone()
                        },
                        output: None,
                    }));
            } else {
                // A pending tool use makes this prefix unstable, so it
                // cannot be a cache breakpoint.
                cache_message = false;
                log::debug!(
                    "skipped tool use {:?} because it is still pending",
                    tool_use
                );
            }
        }

        if cache_message {
            message_ix_to_cache = Some(request.messages.len());
        }
        request.messages.push(request_message);

        if !tool_results_message.content.is_empty() {
            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(tool_results_message);
        }
    }

    // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
    if let Some(message_ix_to_cache) = message_ix_to_cache {
        request.messages[message_ix_to_cache].cache = true;
    }

    // Append a notice about files the agent edited that changed on disk.
    self.attached_tracked_files_state(&mut request.messages, cx);

    request.tools = available_tools;
    request.mode = if model.supports_max_mode() {
        Some(self.completion_mode.into())
    } else {
        Some(CompletionMode::Normal.into())
    };

    request
}
1373
1374 fn to_summarize_request(
1375 &self,
1376 model: &Arc<dyn LanguageModel>,
1377 intent: CompletionIntent,
1378 added_user_message: String,
1379 cx: &App,
1380 ) -> LanguageModelRequest {
1381 let mut request = LanguageModelRequest {
1382 thread_id: None,
1383 prompt_id: None,
1384 intent: Some(intent),
1385 mode: None,
1386 messages: vec![],
1387 tools: Vec::new(),
1388 tool_choice: None,
1389 stop: Vec::new(),
1390 temperature: AgentSettings::temperature_for_model(model, cx),
1391 };
1392
1393 for message in &self.messages {
1394 let mut request_message = LanguageModelRequestMessage {
1395 role: message.role,
1396 content: Vec::new(),
1397 cache: false,
1398 };
1399
1400 for segment in &message.segments {
1401 match segment {
1402 MessageSegment::Text(text) => request_message
1403 .content
1404 .push(MessageContent::Text(text.clone())),
1405 MessageSegment::Thinking { .. } => {}
1406 MessageSegment::RedactedThinking(_) => {}
1407 }
1408 }
1409
1410 if request_message.content.is_empty() {
1411 continue;
1412 }
1413
1414 request.messages.push(request_message);
1415 }
1416
1417 request.messages.push(LanguageModelRequestMessage {
1418 role: Role::User,
1419 content: vec![MessageContent::Text(added_user_message)],
1420 cache: false,
1421 });
1422
1423 request
1424 }
1425
1426 fn attached_tracked_files_state(
1427 &self,
1428 messages: &mut Vec<LanguageModelRequestMessage>,
1429 cx: &App,
1430 ) {
1431 const STALE_FILES_HEADER: &str = include_str!("./prompts/stale_files_prompt_header.txt");
1432
1433 let mut stale_message = String::new();
1434
1435 let action_log = self.action_log.read(cx);
1436
1437 for stale_file in action_log.stale_buffers(cx) {
1438 let Some(file) = stale_file.read(cx).file() else {
1439 continue;
1440 };
1441
1442 if stale_message.is_empty() {
1443 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER.trim()).ok();
1444 }
1445
1446 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1447 }
1448
1449 let mut content = Vec::with_capacity(2);
1450
1451 if !stale_message.is_empty() {
1452 content.push(stale_message.into());
1453 }
1454
1455 if !content.is_empty() {
1456 let context_message = LanguageModelRequestMessage {
1457 role: Role::User,
1458 content,
1459 cache: false,
1460 };
1461
1462 messages.push(context_message);
1463 }
1464 }
1465
/// Streams a completion from `model` for `request`, applying events to the
/// thread as they arrive, and registers the work as a `PendingCompletion`.
///
/// Handles text/thinking/tool-use chunks, token-usage and queue-status
/// updates, stop reasons (including refusal rollback), and error reporting.
/// When the stream stops for tool use, pending tools are dispatched; on
/// error, the completion is canceled and an error event is emitted.
pub fn stream_completion(
    &mut self,
    request: LanguageModelRequest,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    self.tool_use_limit_reached = false;

    let pending_completion_id = post_inc(&mut self.completion_count);
    // Only capture request/response data when a test/debug callback is installed.
    let mut request_callback_parameters = if self.request_callback.is_some() {
        Some((request.clone(), Vec::new()))
    } else {
        None
    };
    let prompt_id = self.last_prompt_id.clone();
    let tool_use_metadata = ToolUseMetadata {
        model: model.clone(),
        thread_id: self.id.clone(),
        prompt_id: prompt_id.clone(),
    };

    self.last_received_chunk_at = Some(Instant::now());

    let task = cx.spawn(async move |thread, cx| {
        let stream_completion_future = model.stream_completion(request, &cx);
        // Baseline for the per-request token delta reported to telemetry below.
        let initial_token_usage =
            thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
        let stream_completion = async {
            let mut events = stream_completion_future.await?;

            let mut stop_reason = StopReason::EndTurn;
            let mut current_token_usage = TokenUsage::default();

            thread
                .update(cx, |_thread, cx| {
                    cx.emit(ThreadEvent::NewRequest);
                })
                .ok();

            // The assistant message created for this response, once one exists.
            let mut request_assistant_message_id = None;

            while let Some(event) = events.next().await {
                if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                    response_events
                        .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                }

                thread.update(cx, |thread, cx| {
                    let event = match event {
                        Ok(event) => event,
                        // Malformed tool-input JSON is recoverable: record it
                        // as a failed tool use and keep consuming the stream.
                        Err(LanguageModelCompletionError::BadInputJson {
                            id,
                            tool_name,
                            raw_input: invalid_input_json,
                            json_parse_error,
                        }) => {
                            thread.receive_invalid_tool_json(
                                id,
                                tool_name,
                                invalid_input_json,
                                json_parse_error,
                                window,
                                cx,
                            );
                            return Ok(());
                        }
                        Err(LanguageModelCompletionError::Other(error)) => {
                            return Err(error);
                        }
                    };

                    match event {
                        LanguageModelCompletionEvent::StartMessage { .. } => {
                            request_assistant_message_id =
                                Some(thread.insert_assistant_message(
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                ));
                        }
                        LanguageModelCompletionEvent::Stop(reason) => {
                            stop_reason = reason;
                        }
                        LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                            thread.update_token_usage_at_last_message(token_usage);
                            // Usage updates are cumulative per request, so add
                            // only the delta since the previous update.
                            thread.cumulative_token_usage = thread.cumulative_token_usage
                                + token_usage
                                - current_token_usage;
                            current_token_usage = token_usage;
                        }
                        LanguageModelCompletionEvent::Text(chunk) => {
                            thread.received_chunk();

                            cx.emit(ThreadEvent::ReceivedTextChunk);
                            if let Some(last_message) = thread.messages.last_mut() {
                                if last_message.role == Role::Assistant
                                    && !thread.tool_use.has_tool_results(last_message.id)
                                {
                                    last_message.push_text(&chunk);
                                    cx.emit(ThreadEvent::StreamedAssistantText(
                                        last_message.id,
                                        chunk,
                                    ));
                                } else {
                                    // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                    // of a new Assistant response.
                                    //
                                    // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                    // will result in duplicating the text of the chunk in the rendered Markdown.
                                    request_assistant_message_id =
                                        Some(thread.insert_assistant_message(
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        ));
                                };
                            }
                        }
                        LanguageModelCompletionEvent::Thinking {
                            text: chunk,
                            signature,
                        } => {
                            thread.received_chunk();

                            if let Some(last_message) = thread.messages.last_mut() {
                                if last_message.role == Role::Assistant
                                    && !thread.tool_use.has_tool_results(last_message.id)
                                {
                                    last_message.push_thinking(&chunk, signature);
                                    cx.emit(ThreadEvent::StreamedAssistantThinking(
                                        last_message.id,
                                        chunk,
                                    ));
                                } else {
                                    // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                    // of a new Assistant response.
                                    //
                                    // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                    // will result in duplicating the text of the chunk in the rendered Markdown.
                                    request_assistant_message_id =
                                        Some(thread.insert_assistant_message(
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        ));
                                };
                            }
                        }
                        LanguageModelCompletionEvent::ToolUse(tool_use) => {
                            // Tool uses attach to the current assistant
                            // message, creating one on demand.
                            let last_assistant_message_id = request_assistant_message_id
                                .unwrap_or_else(|| {
                                    let new_assistant_message_id =
                                        thread.insert_assistant_message(vec![], cx);
                                    request_assistant_message_id =
                                        Some(new_assistant_message_id);
                                    new_assistant_message_id
                                });

                            let tool_use_id = tool_use.id.clone();
                            // Partial input means the tool call is still
                            // streaming; remember it so we can emit the event.
                            let streamed_input = if tool_use.is_input_complete {
                                None
                            } else {
                                Some((&tool_use.input).clone())
                            };

                            let ui_text = thread.tool_use.request_tool_use(
                                last_assistant_message_id,
                                tool_use,
                                tool_use_metadata.clone(),
                                cx,
                            );

                            if let Some(input) = streamed_input {
                                cx.emit(ThreadEvent::StreamedToolUse {
                                    tool_use_id,
                                    ui_text,
                                    input,
                                });
                            }
                        }
                        LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                            if let Some(completion) = thread
                                .pending_completions
                                .iter_mut()
                                .find(|completion| completion.id == pending_completion_id)
                            {
                                match status_update {
                                    CompletionRequestStatus::Queued {
                                        position,
                                    } => {
                                        completion.queue_state = QueueState::Queued { position };
                                    }
                                    CompletionRequestStatus::Started => {
                                        completion.queue_state = QueueState::Started;
                                    }
                                    CompletionRequestStatus::Failed {
                                        code, message, request_id
                                    } => {
                                        anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                    }
                                    CompletionRequestStatus::UsageUpdated {
                                        amount, limit
                                    } => {
                                        let usage = RequestUsage { limit, amount: amount as i32 };

                                        thread.last_usage = Some(usage);
                                    }
                                    CompletionRequestStatus::ToolUseLimitReached => {
                                        thread.tool_use_limit_reached = true;
                                        cx.emit(ThreadEvent::ToolUseLimitReached);
                                    }
                                }
                            }
                        }
                    }

                    thread.touch_updated_at();
                    cx.emit(ThreadEvent::StreamedCompletion);
                    cx.notify();

                    thread.auto_capture_telemetry(cx);
                    Ok(())
                })??;

                // Yield so the UI can process the batch of emitted events.
                smol::future::yield_now().await;
            }

            thread.update(cx, |thread, cx| {
                thread.last_received_chunk_at = None;
                thread
                    .pending_completions
                    .retain(|completion| completion.id != pending_completion_id);

                // If there is a response without tool use, summarize the message. Otherwise,
                // allow two tool uses before summarizing.
                if matches!(thread.summary, ThreadSummary::Pending)
                    && thread.messages.len() >= 2
                    && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                {
                    thread.summarize(cx);
                }
            })?;

            anyhow::Ok(stop_reason)
        };

        let result = stream_completion.await;

        thread
            .update(cx, |thread, cx| {
                thread.finalize_pending_checkpoint(cx);
                match result.as_ref() {
                    Ok(stop_reason) => match stop_reason {
                        StopReason::ToolUse => {
                            let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                            cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                        }
                        StopReason::EndTurn | StopReason::MaxTokens => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });
                        }
                        StopReason::Refusal => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Remove the turn that was refused.
                            //
                            // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                            {
                                let mut messages_to_remove = Vec::new();

                                // Walk backwards collecting messages until we
                                // find the user message that began this turn.
                                for (ix, message) in thread.messages.iter().enumerate().rev() {
                                    messages_to_remove.push(message.id);

                                    if message.role == Role::User {
                                        if ix == 0 {
                                            break;
                                        }

                                        if let Some(prev_message) = thread.messages.get(ix - 1) {
                                            if prev_message.role == Role::Assistant {
                                                break;
                                            }
                                        }
                                    }
                                }

                                for message_id in messages_to_remove {
                                    thread.delete_message(message_id, cx);
                                }
                            }

                            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                header: "Language model refusal".into(),
                                message: "Model refused to generate content for safety reasons.".into(),
                            }));
                        }
                    },
                    Err(error) => {
                        thread.project.update(cx, |project, cx| {
                            project.set_agent_location(None, cx);
                        });

                        // Map known error types to dedicated UI events; fall
                        // back to a generic error message with the full chain.
                        if error.is::<PaymentRequiredError>() {
                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                        } else if let Some(error) =
                            error.downcast_ref::<ModelRequestLimitReachedError>()
                        {
                            cx.emit(ThreadEvent::ShowError(
                                ThreadError::ModelRequestLimitReached { plan: error.plan },
                            ));
                        } else if let Some(known_error) =
                            error.downcast_ref::<LanguageModelKnownError>()
                        {
                            match known_error {
                                LanguageModelKnownError::ContextWindowLimitExceeded {
                                    tokens,
                                } => {
                                    thread.exceeded_window_error = Some(ExceededWindowError {
                                        model_id: model.id(),
                                        token_count: *tokens,
                                    });
                                    cx.notify();
                                }
                            }
                        } else {
                            let error_message = error
                                .chain()
                                .map(|err| err.to_string())
                                .collect::<Vec<_>>()
                                .join("\n");
                            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                header: "Error interacting with language model".into(),
                                message: SharedString::from(error_message.clone()),
                            }));
                        }

                        thread.cancel_last_completion(window, cx);
                    }
                }

                cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                if let Some((request_callback, (request, response_events))) = thread
                    .request_callback
                    .as_mut()
                    .zip(request_callback_parameters.as_ref())
                {
                    request_callback(request, response_events);
                }

                thread.auto_capture_telemetry(cx);

                // Report the tokens consumed by this request (cumulative
                // usage minus the baseline captured before streaming).
                if let Ok(initial_usage) = initial_token_usage {
                    let usage = thread.cumulative_token_usage - initial_usage;

                    telemetry::event!(
                        "Assistant Thread Completion",
                        thread_id = thread.id().to_string(),
                        prompt_id = prompt_id,
                        model = model.telemetry_id(),
                        model_provider = model.provider_id().to_string(),
                        input_tokens = usage.input_tokens,
                        output_tokens = usage.output_tokens,
                        cache_creation_input_tokens = usage.cache_creation_input_tokens,
                        cache_read_input_tokens = usage.cache_read_input_tokens,
                    );
                }
            })
            .ok();
    });

    self.pending_completions.push(PendingCompletion {
        id: pending_completion_id,
        queue_state: QueueState::Sending,
        _task: task,
    });
}
1847
1848 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1849 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1850 println!("No thread summary model");
1851 return;
1852 };
1853
1854 if !model.provider.is_authenticated(cx) {
1855 return;
1856 }
1857
1858 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1859
1860 let request = self.to_summarize_request(
1861 &model.model,
1862 CompletionIntent::ThreadSummarization,
1863 added_user_message.into(),
1864 cx,
1865 );
1866
1867 self.summary = ThreadSummary::Generating;
1868
1869 self.pending_summary = cx.spawn(async move |this, cx| {
1870 let result = async {
1871 let mut messages = model.model.stream_completion(request, &cx).await?;
1872
1873 let mut new_summary = String::new();
1874 while let Some(event) = messages.next().await {
1875 let Ok(event) = event else {
1876 continue;
1877 };
1878 let text = match event {
1879 LanguageModelCompletionEvent::Text(text) => text,
1880 LanguageModelCompletionEvent::StatusUpdate(
1881 CompletionRequestStatus::UsageUpdated { amount, limit },
1882 ) => {
1883 this.update(cx, |thread, _cx| {
1884 thread.last_usage = Some(RequestUsage {
1885 limit,
1886 amount: amount as i32,
1887 });
1888 })?;
1889 continue;
1890 }
1891 _ => continue,
1892 };
1893
1894 let mut lines = text.lines();
1895 new_summary.extend(lines.next());
1896
1897 // Stop if the LLM generated multiple lines.
1898 if lines.next().is_some() {
1899 break;
1900 }
1901 }
1902
1903 anyhow::Ok(new_summary)
1904 }
1905 .await;
1906
1907 this.update(cx, |this, cx| {
1908 match result {
1909 Ok(new_summary) => {
1910 if new_summary.is_empty() {
1911 this.summary = ThreadSummary::Error;
1912 } else {
1913 this.summary = ThreadSummary::Ready(new_summary.into());
1914 }
1915 }
1916 Err(err) => {
1917 this.summary = ThreadSummary::Error;
1918 log::error!("Failed to generate thread summary: {}", err);
1919 }
1920 }
1921 cx.emit(ThreadEvent::SummaryGenerated);
1922 })
1923 .log_err()?;
1924
1925 Some(())
1926 });
1927 }
1928
/// Starts generating a detailed (multi-line) summary of the thread unless
/// one is already generated or generating for the current last message.
///
/// Progress is published through the `detailed_summary_tx` watch channel;
/// on completion the thread is saved via `thread_store` so the summary can
/// be reused later. No-ops when the thread is empty, no summary model is
/// configured, or its provider is unauthenticated.
pub fn start_generating_detailed_summary_if_needed(
    &mut self,
    thread_store: WeakEntity<ThreadStore>,
    cx: &mut Context<Self>,
) {
    let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
        return;
    };

    // A summary tied to the current last message is already in flight or
    // done — nothing to do.
    match &*self.detailed_summary_rx.borrow() {
        DetailedSummaryState::Generating { message_id, .. }
        | DetailedSummaryState::Generated { message_id, .. }
            if *message_id == last_message_id =>
        {
            // Already up-to-date
            return;
        }
        _ => {}
    }

    let Some(ConfiguredModel { model, provider }) =
        LanguageModelRegistry::read_global(cx).thread_summary_model()
    else {
        return;
    };

    if !provider.is_authenticated(cx) {
        return;
    }

    let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

    let request = self.to_summarize_request(
        &model,
        CompletionIntent::ThreadContextSummarization,
        added_user_message.into(),
        cx,
    );

    // Announce the in-progress state before spawning so readers see it
    // immediately.
    *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
        message_id: last_message_id,
    };

    // Replace the detailed summarization task if there is one, cancelling it. It would probably
    // be better to allow the old task to complete, but this would require logic for choosing
    // which result to prefer (the old task could complete after the new one, resulting in a
    // stale summary).
    self.detailed_summary_task = cx.spawn(async move |thread, cx| {
        let stream = model.stream_completion_text(request, &cx);
        // If the request fails to start, roll the state back to NotGenerated.
        let Some(mut messages) = stream.await.log_err() else {
            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() =
                        DetailedSummaryState::NotGenerated;
                })
                .ok()?;
            return None;
        };

        let mut new_detailed_summary = String::new();

        // Accumulate the full streamed text; individual chunk errors are
        // logged and skipped.
        while let Some(chunk) = messages.stream.next().await {
            if let Some(chunk) = chunk.log_err() {
                new_detailed_summary.push_str(&chunk);
            }
        }

        thread
            .update(cx, |thread, _cx| {
                *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                    text: new_detailed_summary.into(),
                    message_id: last_message_id,
                };
            })
            .ok()?;

        // Save thread so its summary can be reused later
        if let Some(thread) = thread.upgrade() {
            if let Ok(Ok(save_task)) = cx.update(|cx| {
                thread_store
                    .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            }) {
                save_task.await.log_err();
            }
        }

        Some(())
    });
}
2018
2019 pub async fn wait_for_detailed_summary_or_text(
2020 this: &Entity<Self>,
2021 cx: &mut AsyncApp,
2022 ) -> Option<SharedString> {
2023 let mut detailed_summary_rx = this
2024 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2025 .ok()?;
2026 loop {
2027 match detailed_summary_rx.recv().await? {
2028 DetailedSummaryState::Generating { .. } => {}
2029 DetailedSummaryState::NotGenerated => {
2030 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2031 }
2032 DetailedSummaryState::Generated { text, .. } => return Some(text),
2033 }
2034 }
2035 }
2036
2037 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2038 self.detailed_summary_rx
2039 .borrow()
2040 .text()
2041 .unwrap_or_else(|| self.text().into())
2042 }
2043
/// Reports whether a detailed summary is currently being generated.
pub fn is_generating_detailed_summary(&self) -> bool {
    matches!(
        &*self.detailed_summary_rx.borrow(),
        DetailedSummaryState::Generating { .. }
    )
}
2050
2051 pub fn use_pending_tools(
2052 &mut self,
2053 window: Option<AnyWindowHandle>,
2054 cx: &mut Context<Self>,
2055 model: Arc<dyn LanguageModel>,
2056 ) -> Vec<PendingToolUse> {
2057 self.auto_capture_telemetry(cx);
2058 let request =
2059 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2060 let pending_tool_uses = self
2061 .tool_use
2062 .pending_tool_uses()
2063 .into_iter()
2064 .filter(|tool_use| tool_use.status.is_idle())
2065 .cloned()
2066 .collect::<Vec<_>>();
2067
2068 for tool_use in pending_tool_uses.iter() {
2069 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2070 if tool.needs_confirmation(&tool_use.input, cx)
2071 && !AgentSettings::get_global(cx).always_allow_tool_actions
2072 {
2073 self.tool_use.confirm_tool_use(
2074 tool_use.id.clone(),
2075 tool_use.ui_text.clone(),
2076 tool_use.input.clone(),
2077 request.clone(),
2078 tool,
2079 );
2080 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2081 } else {
2082 self.run_tool(
2083 tool_use.id.clone(),
2084 tool_use.ui_text.clone(),
2085 tool_use.input.clone(),
2086 request.clone(),
2087 tool,
2088 model.clone(),
2089 window,
2090 cx,
2091 );
2092 }
2093 } else {
2094 self.handle_hallucinated_tool_use(
2095 tool_use.id.clone(),
2096 tool_use.name.clone(),
2097 window,
2098 cx,
2099 );
2100 }
2101 }
2102
2103 pending_tool_uses
2104 }
2105
2106 pub fn handle_hallucinated_tool_use(
2107 &mut self,
2108 tool_use_id: LanguageModelToolUseId,
2109 hallucinated_tool_name: Arc<str>,
2110 window: Option<AnyWindowHandle>,
2111 cx: &mut Context<Thread>,
2112 ) {
2113 let available_tools = self.tools.read(cx).enabled_tools(cx);
2114
2115 let tool_list = available_tools
2116 .iter()
2117 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2118 .collect::<Vec<_>>()
2119 .join("\n");
2120
2121 let error_message = format!(
2122 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2123 hallucinated_tool_name, tool_list
2124 );
2125
2126 let pending_tool_use = self.tool_use.insert_tool_output(
2127 tool_use_id.clone(),
2128 hallucinated_tool_name,
2129 Err(anyhow!("Missing tool call: {error_message}")),
2130 self.configured_model.as_ref(),
2131 );
2132
2133 cx.emit(ThreadEvent::MissingToolUse {
2134 tool_use_id: tool_use_id.clone(),
2135 ui_text: error_message.into(),
2136 });
2137
2138 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2139 }
2140
2141 pub fn receive_invalid_tool_json(
2142 &mut self,
2143 tool_use_id: LanguageModelToolUseId,
2144 tool_name: Arc<str>,
2145 invalid_json: Arc<str>,
2146 error: String,
2147 window: Option<AnyWindowHandle>,
2148 cx: &mut Context<Thread>,
2149 ) {
2150 log::error!("The model returned invalid input JSON: {invalid_json}");
2151
2152 let pending_tool_use = self.tool_use.insert_tool_output(
2153 tool_use_id.clone(),
2154 tool_name,
2155 Err(anyhow!("Error parsing input JSON: {error}")),
2156 self.configured_model.as_ref(),
2157 );
2158 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2159 pending_tool_use.ui_text.clone()
2160 } else {
2161 log::error!(
2162 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2163 );
2164 format!("Unknown tool {}", tool_use_id).into()
2165 };
2166
2167 cx.emit(ThreadEvent::InvalidToolInput {
2168 tool_use_id: tool_use_id.clone(),
2169 ui_text,
2170 invalid_input_json: invalid_json,
2171 });
2172
2173 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2174 }
2175
2176 pub fn run_tool(
2177 &mut self,
2178 tool_use_id: LanguageModelToolUseId,
2179 ui_text: impl Into<SharedString>,
2180 input: serde_json::Value,
2181 request: Arc<LanguageModelRequest>,
2182 tool: Arc<dyn Tool>,
2183 model: Arc<dyn LanguageModel>,
2184 window: Option<AnyWindowHandle>,
2185 cx: &mut Context<Thread>,
2186 ) {
2187 let task =
2188 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2189 self.tool_use
2190 .run_pending_tool(tool_use_id, ui_text.into(), task);
2191 }
2192
/// Runs `tool` (or an immediate error result when the tool is disabled) and
/// spawns a task that, on completion, records the output on the thread and
/// invokes `tool_finished`.
fn spawn_tool_use(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    request: Arc<LanguageModelRequest>,
    input: serde_json::Value,
    tool: Arc<dyn Tool>,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) -> Task<()> {
    let tool_name: Arc<str> = tool.name().into();

    // A disabled tool short-circuits to a ready error task instead of running.
    let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
        Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
    } else {
        tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        )
    };

    // Store the card separately if it exists
    if let Some(card) = tool_result.card.clone() {
        self.tool_use
            .insert_tool_result_card(tool_use_id.clone(), card);
    }

    cx.spawn({
        async move |thread: WeakEntity<Thread>, cx| {
            let output = tool_result.output.await;

            // The thread may have been dropped while the tool ran; `.ok()`
            // discards the update error in that case.
            thread
                .update(cx, |thread, cx| {
                    let pending_tool_use = thread.tool_use.insert_tool_output(
                        tool_use_id.clone(),
                        tool_name,
                        output,
                        thread.configured_model.as_ref(),
                    );
                    thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                })
                .ok();
        }
    })
}
2243
2244 fn tool_finished(
2245 &mut self,
2246 tool_use_id: LanguageModelToolUseId,
2247 pending_tool_use: Option<PendingToolUse>,
2248 canceled: bool,
2249 window: Option<AnyWindowHandle>,
2250 cx: &mut Context<Self>,
2251 ) {
2252 if self.all_tools_finished() {
2253 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2254 if !canceled {
2255 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2256 }
2257 self.auto_capture_telemetry(cx);
2258 }
2259 }
2260
2261 cx.emit(ThreadEvent::ToolFinished {
2262 tool_use_id,
2263 pending_tool_use,
2264 });
2265 }
2266
2267 /// Cancels the last pending completion, if there are any pending.
2268 ///
2269 /// Returns whether a completion was canceled.
2270 pub fn cancel_last_completion(
2271 &mut self,
2272 window: Option<AnyWindowHandle>,
2273 cx: &mut Context<Self>,
2274 ) -> bool {
2275 let mut canceled = self.pending_completions.pop().is_some();
2276
2277 for pending_tool_use in self.tool_use.cancel_pending() {
2278 canceled = true;
2279 self.tool_finished(
2280 pending_tool_use.id.clone(),
2281 Some(pending_tool_use),
2282 true,
2283 window,
2284 cx,
2285 );
2286 }
2287
2288 if canceled {
2289 cx.emit(ThreadEvent::CompletionCanceled);
2290
2291 // When canceled, we always want to insert the checkpoint.
2292 // (We skip over finalize_pending_checkpoint, because it
2293 // would conclude we didn't have anything to insert here.)
2294 if let Some(checkpoint) = self.pending_checkpoint.take() {
2295 self.insert_checkpoint(checkpoint, cx);
2296 }
2297 } else {
2298 self.finalize_pending_checkpoint(cx);
2299 }
2300
2301 canceled
2302 }
2303
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2311
    /// Returns the thread-level feedback rating, if one has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2315
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2319
    /// Records a rating for a specific message and reports it to telemetry,
    /// along with a serialized copy of the thread and a project snapshot.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Re-submitting the same rating for the same message is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs entity/app context before moving to
        // the background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the telemetry event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2377
    /// Records feedback for the thread as a whole. If there is an assistant
    /// message, the feedback is attributed to the most recent one; otherwise
    /// it is stored at thread level and reported without message details.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: report a thread-level rating with the
            // serialized thread and a project snapshot, but no message fields.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Fall back to a null payload if serialization to JSON fails.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2423
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are collected concurrently; dirty buffer paths are
    /// gathered afterwards on the foreground via `cx.update`.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record paths of dirty (unsaved) buffers; if the app context is
            // gone, we simply report no unsaved buffers.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2461
2462 fn worktree_snapshot(
2463 worktree: Entity<project::Worktree>,
2464 git_store: Entity<GitStore>,
2465 cx: &App,
2466 ) -> Task<WorktreeSnapshot> {
2467 cx.spawn(async move |cx| {
2468 // Get worktree path and snapshot
2469 let worktree_info = cx.update(|app_cx| {
2470 let worktree = worktree.read(app_cx);
2471 let path = worktree.abs_path().to_string_lossy().to_string();
2472 let snapshot = worktree.snapshot();
2473 (path, snapshot)
2474 });
2475
2476 let Ok((worktree_path, _snapshot)) = worktree_info else {
2477 return WorktreeSnapshot {
2478 worktree_path: String::new(),
2479 git_state: None,
2480 };
2481 };
2482
2483 let git_state = git_store
2484 .update(cx, |git_store, cx| {
2485 git_store
2486 .repositories()
2487 .values()
2488 .find(|repo| {
2489 repo.read(cx)
2490 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2491 .is_some()
2492 })
2493 .cloned()
2494 })
2495 .ok()
2496 .flatten()
2497 .map(|repo| {
2498 repo.update(cx, |repo, _| {
2499 let current_branch =
2500 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2501 repo.send_job(None, |state, _| async move {
2502 let RepositoryState::Local { backend, .. } = state else {
2503 return GitState {
2504 remote_url: None,
2505 head_sha: None,
2506 current_branch,
2507 diff: None,
2508 };
2509 };
2510
2511 let remote_url = backend.remote_url("origin");
2512 let head_sha = backend.head_sha().await;
2513 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2514
2515 GitState {
2516 remote_url,
2517 head_sha,
2518 current_branch,
2519 diff,
2520 }
2521 })
2522 })
2523 });
2524
2525 let git_state = match git_state {
2526 Some(git_state) => match git_state.ok() {
2527 Some(git_state) => git_state.await.ok(),
2528 None => None,
2529 },
2530 None => None,
2531 };
2532
2533 WorktreeSnapshot {
2534 worktree_path,
2535 git_state,
2536 }
2537 })
2538 }
2539
2540 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2541 let mut markdown = Vec::new();
2542
2543 let summary = self.summary().or_default();
2544 writeln!(markdown, "# {summary}\n")?;
2545
2546 for message in self.messages() {
2547 writeln!(
2548 markdown,
2549 "## {role}\n",
2550 role = match message.role {
2551 Role::User => "User",
2552 Role::Assistant => "Agent",
2553 Role::System => "System",
2554 }
2555 )?;
2556
2557 if !message.loaded_context.text.is_empty() {
2558 writeln!(markdown, "{}", message.loaded_context.text)?;
2559 }
2560
2561 if !message.loaded_context.images.is_empty() {
2562 writeln!(
2563 markdown,
2564 "\n{} images attached as context.\n",
2565 message.loaded_context.images.len()
2566 )?;
2567 }
2568
2569 for segment in &message.segments {
2570 match segment {
2571 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2572 MessageSegment::Thinking { text, .. } => {
2573 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2574 }
2575 MessageSegment::RedactedThinking(_) => {}
2576 }
2577 }
2578
2579 for tool_use in self.tool_uses_for_message(message.id, cx) {
2580 writeln!(
2581 markdown,
2582 "**Use Tool: {} ({})**",
2583 tool_use.name, tool_use.id
2584 )?;
2585 writeln!(markdown, "```json")?;
2586 writeln!(
2587 markdown,
2588 "{}",
2589 serde_json::to_string_pretty(&tool_use.input)?
2590 )?;
2591 writeln!(markdown, "```")?;
2592 }
2593
2594 for tool_result in self.tool_results_for_message(message.id) {
2595 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2596 if tool_result.is_error {
2597 write!(markdown, " (Error)")?;
2598 }
2599
2600 writeln!(markdown, "**\n")?;
2601 match &tool_result.content {
2602 LanguageModelToolResultContent::Text(text) => {
2603 writeln!(markdown, "{text}")?;
2604 }
2605 LanguageModelToolResultContent::Image(image) => {
2606 writeln!(markdown, "", image.source)?;
2607 }
2608 }
2609
2610 if let Some(output) = tool_result.output.as_ref() {
2611 writeln!(
2612 markdown,
2613 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2614 serde_json::to_string_pretty(output)?
2615 )?;
2616 }
2617 }
2618 }
2619
2620 Ok(String::from_utf8_lossy(&markdown).to_string())
2621 }
2622
    /// Accepts agent edits within `buffer_range`; delegates to
    /// [`ActionLog::keep_edits_in_range`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2633
    /// Accepts every outstanding agent edit; delegates to
    /// [`ActionLog::keep_all_edits`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2638
    /// Rejects agent edits within the given ranges; delegates to
    /// [`ActionLog::reject_edits_in_ranges`], which performs the revert
    /// asynchronously.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2649
    /// The action log tracking buffer edits made during this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2653
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2657
2658 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2659 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2660 return;
2661 }
2662
2663 let now = Instant::now();
2664 if let Some(last) = self.last_auto_capture_at {
2665 if now.duration_since(last).as_secs() < 10 {
2666 return;
2667 }
2668 }
2669
2670 self.last_auto_capture_at = Some(now);
2671
2672 let thread_id = self.id().clone();
2673 let github_login = self
2674 .project
2675 .read(cx)
2676 .user_store()
2677 .read(cx)
2678 .current_user()
2679 .map(|user| user.github_login.clone());
2680 let client = self.project.read(cx).client();
2681 let serialize_task = self.serialize(cx);
2682
2683 cx.background_executor()
2684 .spawn(async move {
2685 if let Ok(serialized_thread) = serialize_task.await {
2686 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2687 telemetry::event!(
2688 "Agent Thread Auto-Captured",
2689 thread_id = thread_id.to_string(),
2690 thread_data = thread_data,
2691 auto_capture_reason = "tracked_user",
2692 github_login = github_login
2693 );
2694
2695 client.telemetry().flush_events().await;
2696 }
2697 }
2698 })
2699 .detach();
2700 }
2701
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2705
2706 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2707 let Some(model) = self.configured_model.as_ref() else {
2708 return TotalTokenUsage::default();
2709 };
2710
2711 let max = model.model.max_token_count();
2712
2713 let index = self
2714 .messages
2715 .iter()
2716 .position(|msg| msg.id == message_id)
2717 .unwrap_or(0);
2718
2719 if index == 0 {
2720 return TotalTokenUsage { total: 0, max };
2721 }
2722
2723 let token_usage = &self
2724 .request_token_usage
2725 .get(index - 1)
2726 .cloned()
2727 .unwrap_or_default();
2728
2729 TotalTokenUsage {
2730 total: token_usage.total_tokens() as usize,
2731 max,
2732 }
2733 }
2734
2735 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2736 let model = self.configured_model.as_ref()?;
2737
2738 let max = model.model.max_token_count();
2739
2740 if let Some(exceeded_error) = &self.exceeded_window_error {
2741 if model.model.id() == exceeded_error.model_id {
2742 return Some(TotalTokenUsage {
2743 total: exceeded_error.token_count,
2744 max,
2745 });
2746 }
2747 }
2748
2749 let total = self
2750 .token_usage_at_last_message()
2751 .unwrap_or_default()
2752 .total_tokens() as usize;
2753
2754 Some(TotalTokenUsage { total, max })
2755 }
2756
    /// Token usage recorded for the most recent message, falling back to the
    /// last recorded entry when the usage vector is shorter than the message
    /// list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2763
    /// Overwrites the usage entry for the last message, first growing the
    /// usage vector to match the message count if it is shorter.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any newly-created entries with the most recent known usage so
        // they are not zeroed out.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2773
2774 pub fn deny_tool_use(
2775 &mut self,
2776 tool_use_id: LanguageModelToolUseId,
2777 tool_name: Arc<str>,
2778 window: Option<AnyWindowHandle>,
2779 cx: &mut Context<Self>,
2780 ) {
2781 let err = Err(anyhow::anyhow!(
2782 "Permission to run tool action denied by user"
2783 ));
2784
2785 self.tool_use.insert_tool_output(
2786 tool_use_id.clone(),
2787 tool_name,
2788 err,
2789 self.configured_model.as_ref(),
2790 );
2791 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2792 }
2793}
2794
/// User-visible errors that can abort a thread's completion.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before more requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The user exhausted the request quota for their plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and longer message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2807
/// Events emitted by a [`Thread`] as completions stream, messages change, and
/// tools run; listeners such as the thread UI react to these.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A [`ThreadError`] that should be surfaced to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Emitted whenever a tool call completes, succeeds or not (see
    /// `Thread::tool_finished`).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Emitted by `Thread::cancel_editing`; listeners should abort any
    /// in-progress editing operations.
    CancelEditing,
    /// Emitted when `cancel_last_completion` actually canceled something.
    CompletionCanceled,
}
2851
// Lets `Thread` emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2853
/// An in-flight completion request tracked by the thread.
struct PendingCompletion {
    // Distinguishes concurrent completions; popped off in
    // `cancel_last_completion`.
    id: usize,
    queue_state: QueueState,
    // Owns the task driving the completion stream. NOTE(review): presumably
    // dropping this cancels the stream (gpui `Task` semantics) — confirm.
    _task: Task<()>,
}
2859
2860#[cfg(test)]
2861mod tests {
2862 use super::*;
2863 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2864 use agent_settings::{AgentSettings, LanguageModelParameters};
2865 use assistant_tool::ToolRegistry;
2866 use editor::EditorSettings;
2867 use gpui::TestAppContext;
2868 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2869 use project::{FakeFs, Project};
2870 use prompt_store::PromptBuilder;
2871 use serde_json::json;
2872 use settings::{Settings, SettingsStore};
2873 use std::sync::Arc;
2874 use theme::ThemeSettings;
2875 use util::path;
2876 use workspace::Workspace;
2877
    /// Verifies that a file attached as context is rendered into the user
    /// message's loaded-context text and into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: message 0 is the system prompt, message 1
        // is the user message with the context prepended.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2954
    /// Verifies that each new user message only carries context that earlier
    /// messages have not already included, and that editing from an earlier
    /// message re-collects the contexts added after it.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages (plus the system prompt)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Editing from message 2 onward: contexts added by messages >= 2 are
        // considered "new" again (file2, file3, file4), but message 1's is not.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3110
    /// Verifies that messages inserted with an empty context produce no
    /// context block, in either the message object or the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request (message 0 is the system prompt)
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3188
    /// Verifies that when a buffer previously attached as context is modified,
    /// the next completion request ends with a "stale files" notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3276
    /// Verifies how `model_parameters` temperature overrides are matched:
    /// a setting applies when both provider and model match, when only one is
    /// specified and matches, and never when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider: the override must NOT apply.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3370
3371 #[gpui::test]
3372 async fn test_thread_summary(cx: &mut TestAppContext) {
3373 init_test_settings(cx);
3374
3375 let project = create_test_project(cx, json!({})).await;
3376
3377 let (_, _thread_store, thread, _context_store, model) =
3378 setup_test_environment(cx, project.clone()).await;
3379
3380 // Initial state should be pending
3381 thread.read_with(cx, |thread, _| {
3382 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3383 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3384 });
3385
3386 // Manually setting the summary should not be allowed in this state
3387 thread.update(cx, |thread, cx| {
3388 thread.set_summary("This should not work", cx);
3389 });
3390
3391 thread.read_with(cx, |thread, _| {
3392 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3393 });
3394
3395 // Send a message
3396 thread.update(cx, |thread, cx| {
3397 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3398 thread.send_to_model(
3399 model.clone(),
3400 CompletionIntent::ThreadSummarization,
3401 None,
3402 cx,
3403 );
3404 });
3405
3406 let fake_model = model.as_fake();
3407 simulate_successful_response(&fake_model, cx);
3408
3409 // Should start generating summary when there are >= 2 messages
3410 thread.read_with(cx, |thread, _| {
3411 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3412 });
3413
3414 // Should not be able to set the summary while generating
3415 thread.update(cx, |thread, cx| {
3416 thread.set_summary("This should not work either", cx);
3417 });
3418
3419 thread.read_with(cx, |thread, _| {
3420 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3421 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3422 });
3423
3424 cx.run_until_parked();
3425 fake_model.stream_last_completion_response("Brief");
3426 fake_model.stream_last_completion_response(" Introduction");
3427 fake_model.end_last_completion_stream();
3428 cx.run_until_parked();
3429
3430 // Summary should be set
3431 thread.read_with(cx, |thread, _| {
3432 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3433 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3434 });
3435
3436 // Now we should be able to set a summary
3437 thread.update(cx, |thread, cx| {
3438 thread.set_summary("Brief Intro", cx);
3439 });
3440
3441 thread.read_with(cx, |thread, _| {
3442 assert_eq!(thread.summary().or_default(), "Brief Intro");
3443 });
3444
3445 // Test setting an empty summary (should default to DEFAULT)
3446 thread.update(cx, |thread, cx| {
3447 thread.set_summary("", cx);
3448 });
3449
3450 thread.read_with(cx, |thread, _| {
3451 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3452 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3453 });
3454 }
3455
3456 #[gpui::test]
3457 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3458 init_test_settings(cx);
3459
3460 let project = create_test_project(cx, json!({})).await;
3461
3462 let (_, _thread_store, thread, _context_store, model) =
3463 setup_test_environment(cx, project.clone()).await;
3464
3465 test_summarize_error(&model, &thread, cx);
3466
3467 // Now we should be able to set a summary
3468 thread.update(cx, |thread, cx| {
3469 thread.set_summary("Brief Intro", cx);
3470 });
3471
3472 thread.read_with(cx, |thread, _| {
3473 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3474 assert_eq!(thread.summary().or_default(), "Brief Intro");
3475 });
3476 }
3477
3478 #[gpui::test]
3479 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3480 init_test_settings(cx);
3481
3482 let project = create_test_project(cx, json!({})).await;
3483
3484 let (_, _thread_store, thread, _context_store, model) =
3485 setup_test_environment(cx, project.clone()).await;
3486
3487 test_summarize_error(&model, &thread, cx);
3488
3489 // Sending another message should not trigger another summarize request
3490 thread.update(cx, |thread, cx| {
3491 thread.insert_user_message(
3492 "How are you?",
3493 ContextLoadResult::default(),
3494 None,
3495 vec![],
3496 cx,
3497 );
3498 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3499 });
3500
3501 let fake_model = model.as_fake();
3502 simulate_successful_response(&fake_model, cx);
3503
3504 thread.read_with(cx, |thread, _| {
3505 // State is still Error, not Generating
3506 assert!(matches!(thread.summary(), ThreadSummary::Error));
3507 });
3508
3509 // But the summarize request can be invoked manually
3510 thread.update(cx, |thread, cx| {
3511 thread.summarize(cx);
3512 });
3513
3514 thread.read_with(cx, |thread, _| {
3515 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3516 });
3517
3518 cx.run_until_parked();
3519 fake_model.stream_last_completion_response("A successful summary");
3520 fake_model.end_last_completion_stream();
3521 cx.run_until_parked();
3522
3523 thread.read_with(cx, |thread, _| {
3524 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3525 assert_eq!(thread.summary().or_default(), "A successful summary");
3526 });
3527 }
3528
3529 fn test_summarize_error(
3530 model: &Arc<dyn LanguageModel>,
3531 thread: &Entity<Thread>,
3532 cx: &mut TestAppContext,
3533 ) {
3534 thread.update(cx, |thread, cx| {
3535 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3536 thread.send_to_model(
3537 model.clone(),
3538 CompletionIntent::ThreadSummarization,
3539 None,
3540 cx,
3541 );
3542 });
3543
3544 let fake_model = model.as_fake();
3545 simulate_successful_response(&fake_model, cx);
3546
3547 thread.read_with(cx, |thread, _| {
3548 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3549 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3550 });
3551
3552 // Simulate summary request ending
3553 cx.run_until_parked();
3554 fake_model.end_last_completion_stream();
3555 cx.run_until_parked();
3556
3557 // State is set to Error and default message
3558 thread.read_with(cx, |thread, _| {
3559 assert!(matches!(thread.summary(), ThreadSummary::Error));
3560 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3561 });
3562 }
3563
    /// Streams a canned assistant reply on `fake_model`'s most recent
    /// completion and closes the stream, parking the executor before and
    /// after so pending effects are flushed.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3570
    /// Registers the global settings and subsystems the tests in this module
    /// depend on. Must run before any project or thread is created.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must exist before anything registers settings.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3587
3588 // Helper to create a test project with test files
3589 async fn create_test_project(
3590 cx: &mut TestAppContext,
3591 files: serde_json::Value,
3592 ) -> Entity<Project> {
3593 let fs = FakeFs::new(cx.executor());
3594 fs.insert_tree(path!("/test"), files).await;
3595 Project::test(fs, [path!("/test").as_ref()], cx).await
3596 }
3597
3598 async fn setup_test_environment(
3599 cx: &mut TestAppContext,
3600 project: Entity<Project>,
3601 ) -> (
3602 Entity<Workspace>,
3603 Entity<ThreadStore>,
3604 Entity<Thread>,
3605 Entity<ContextStore>,
3606 Arc<dyn LanguageModel>,
3607 ) {
3608 let (workspace, cx) =
3609 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3610
3611 let thread_store = cx
3612 .update(|_, cx| {
3613 ThreadStore::load(
3614 project.clone(),
3615 cx.new(|_| ToolWorkingSet::default()),
3616 None,
3617 Arc::new(PromptBuilder::new(None).unwrap()),
3618 cx,
3619 )
3620 })
3621 .await
3622 .unwrap();
3623
3624 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3625 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3626
3627 let provider = Arc::new(FakeLanguageModelProvider);
3628 let model = provider.test_model();
3629 let model: Arc<dyn LanguageModel> = Arc::new(model);
3630
3631 cx.update(|_, cx| {
3632 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3633 registry.set_default_model(
3634 Some(ConfiguredModel {
3635 provider: provider.clone(),
3636 model: model.clone(),
3637 }),
3638 cx,
3639 );
3640 registry.set_thread_summary_model(
3641 Some(ConfiguredModel {
3642 provider,
3643 model: model.clone(),
3644 }),
3645 cx,
3646 );
3647 })
3648 });
3649
3650 (workspace, thread_store, thread, context_store, model)
3651 }
3652
3653 async fn add_file_to_context(
3654 project: &Entity<Project>,
3655 context_store: &Entity<ContextStore>,
3656 path: &str,
3657 cx: &mut TestAppContext,
3658 ) -> Result<Entity<language::Buffer>> {
3659 let buffer_path = project
3660 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3661 .unwrap();
3662
3663 let buffer = project
3664 .update(cx, |project, cx| {
3665 project.open_buffer(buffer_path.clone(), cx)
3666 })
3667 .await
3668 .unwrap();
3669
3670 context_store.update(cx, |context_store, cx| {
3671 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3672 });
3673
3674 Ok(buffer)
3675 }
3676}