1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Globally unique identifier of a [`Thread`], stored as a cheaply-clonable string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh, random prompt ID backed by a v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Monotonically increasing identifier of a message within a single [`Thread`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message's text covered by the crease.
    pub range: Range<usize>,
    /// Icon and label used to re-render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached/loaded when the message was submitted.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages (e.g. auto-inserted "continue") are not shown in the UI.
    pub is_hidden: bool,
}
120
121impl Message {
122 /// Returns whether the message contains any meaningful text that should be displayed
123 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
124 pub fn should_display_content(&self) -> bool {
125 self.segments.iter().all(|segment| segment.should_display())
126 }
127
128 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
129 if let Some(MessageSegment::Thinking {
130 text: segment,
131 signature: current_signature,
132 }) = self.segments.last_mut()
133 {
134 if let Some(signature) = signature {
135 *current_signature = Some(signature);
136 }
137 segment.push_str(text);
138 } else {
139 self.segments.push(MessageSegment::Thinking {
140 text: text.to_string(),
141 signature,
142 });
143 }
144 }
145
146 pub fn push_text(&mut self, text: &str) {
147 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Text(text.to_string()));
151 }
152 }
153
154 pub fn to_string(&self) -> String {
155 let mut result = String::new();
156
157 if !self.loaded_context.text.is_empty() {
158 result.push_str(&self.loaded_context.text);
159 }
160
161 for segment in &self.segments {
162 match segment {
163 MessageSegment::Text(text) => result.push_str(text),
164 MessageSegment::Thinking { text, .. } => {
165 result.push_str("<think>\n");
166 result.push_str(text);
167 result.push_str("\n</think>");
168 }
169 MessageSegment::RedactedThinking(_) => {}
170 }
171 }
172
173 result
174 }
175}
176
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary response text.
    Text(String),
    /// Model "thinking" output, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque provider-encoded thinking data; never user-visible.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            // Previously these arms returned `is_empty()`, inverting the
            // intended meaning: empty segments claimed to be displayable
            // while real text did not.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
196
/// Point-in-time capture of the project's worktrees and their git metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // NOTE(review): presumably paths of buffers with unsaved edits at capture
    // time — confirm against the snapshot-producing code.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, when one could be produced.
    pub diff: Option<String>,
}
217
/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
229
/// Progress of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// The restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` holds the failure description.
    Error {
        message_id: MessageId,
        error: String,
    },
}
239
240impl LastRestoreCheckpoint {
241 pub fn message_id(&self) -> MessageId {
242 match self {
243 LastRestoreCheckpoint::Pending { message_id } => *message_id,
244 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
245 }
246 }
247}
248
/// Lifecycle of the detailed (long-form) thread summary.
///
/// NOTE(review): `message_id` appears to mark the last message covered by the
/// summary — confirm at the generation call sites.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in progress.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
261
262impl DetailedSummaryState {
263 fn text(&self) -> Option<SharedString> {
264 if let Self::Generated { text, .. } = self {
265 Some(text.clone())
266 } else {
267 None
268 }
269 }
270}
271
/// Running token count for a thread, relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; `0` when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            // Previously an unparsable override panicked via `unwrap()` (and
            // `unwrap_or(String)` allocated eagerly); fall back to the default
            // threshold instead.
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// At or above the warning threshold (80% by default).
    Warning,
    /// At or above the window size.
    Exceeded,
}
316
/// State of a pending completion request's place in line.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is queued behind `position` others.
    Queued { position: usize },
    /// The request has started.
    Started,
}
323
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    /// In-flight task generating the short summary, if any.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary, if any.
    detailed_summary_task: Task<Option<()>>,
    /// Watch channel broadcasting detailed-summary progress to observers.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages in ascending id order; ids come from `next_message_id`.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message, used for restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once the current turn completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Time the last streaming chunk arrived; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Debug/test hook invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
364
/// State of the short, single-line thread summary.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary has been requested yet.
    Pending,
    /// A summary is being generated.
    Generating,
    /// A summary is available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
372
373impl ThreadSummary {
374 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
375
376 pub fn or_default(&self) -> SharedString {
377 self.unwrap_or(Self::DEFAULT)
378 }
379
380 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
381 self.ready().unwrap_or_else(|| message.into())
382 }
383
384 pub fn ready(&self) -> Option<SharedString> {
385 match self {
386 ThreadSummary::Ready(summary) => Some(summary.clone()),
387 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
388 }
389 }
390}
391
/// Details recorded when a request overflowed the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
399
400impl Thread {
    /// Creates an empty thread for `project`, selecting the globally
    /// configured default model and capturing an initial project snapshot
    /// in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Snapshot asynchronously so thread creation doesn't block on
                // worktree/git scanning.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
454
    /// Rebuilds a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, completion mode, and the previously
    /// selected model (falling back to the registry default when the saved
    /// model is no longer available). `window` is `None` in headless mode.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the context text is persisted; handles and images
                    // are not restored from disk.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
575
    /// Installs a hook that observes every completion request and the raw
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The stable, globally unique ID of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Stamps the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
607
608 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
609 if self.configured_model.is_none() {
610 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
611 }
612 self.configured_model.clone()
613 }
614
    /// The currently selected model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for subsequent completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the short thread summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
627
628 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
629 let current_summary = match &self.summary {
630 ThreadSummary::Pending | ThreadSummary::Generating => return,
631 ThreadSummary::Ready(summary) => summary,
632 ThreadSummary::Error => &ThreadSummary::DEFAULT,
633 };
634
635 let mut new_summary = new_summary.into();
636
637 if new_summary.is_empty() {
638 new_summary = ThreadSummary::DEFAULT;
639 }
640
641 if current_summary != &new_summary {
642 self.summary = ThreadSummary::Ready(new_summary);
643 cx.emit(ThreadEvent::SummaryChanged);
644 }
645 }
646
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
654
655 pub fn message(&self, id: MessageId) -> Option<&Message> {
656 let index = self
657 .messages
658 .binary_search_by(|message| message.id.cmp(&id))
659 .ok()?;
660
661 self.messages.get(index)
662 }
663
    /// All messages in id order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the `None`-when-idle claim relies on
    /// `last_received_chunk_at` being cleared when generation ends —
    /// confirm against the completion path.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (resets staleness).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
690
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
716
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages past the restore point and clear the marker.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
752
753 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
754 let pending_checkpoint = if self.is_generating() {
755 return;
756 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
757 checkpoint
758 } else {
759 return;
760 };
761
762 self.finalize_checkpoint(pending_checkpoint, cx);
763 }
764
    /// Compares `pending_checkpoint` against the repository's current state
    /// and keeps it only when the two differ (i.e. the turn changed files).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // If comparison itself fails, treat the checkpoints as
                // differing and keep the pending one.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed; keep the pending one rather
            // than losing the restore point.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Registers `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
806
    /// State of the most recent checkpoint restore, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes `message_id` and every later message, along with their
    /// recorded checkpoints. A no-op when the id is unknown.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Context items attached to the message with `id` (empty when none).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
832
833 pub fn is_turn_end(&self, ix: usize) -> bool {
834 if self.messages.is_empty() {
835 return false;
836 }
837
838 if !self.is_generating() && ix == self.messages.len() - 1 {
839 return true;
840 }
841
842 let Some(message) = self.messages.get(ix) else {
843 return false;
844 };
845
846 if message.role != Role::Assistant {
847 return false;
848 }
849
850 self.messages
851 .get(ix + 1)
852 .and_then(|message| {
853 self.message(message.id)
854 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
855 })
856 .unwrap_or(false)
857 }
858
    /// Usage reported by the provider for the most recent request.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider reported that the tool-use limit was hit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
885
    /// Tool uses issued by the assistant message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of tool uses issued by the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a single tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a completed tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card associated with a tool use, when the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
914
915 /// Return tools that are both enabled and supported by the model
916 pub fn available_tools(
917 &self,
918 cx: &App,
919 model: Arc<dyn LanguageModel>,
920 ) -> Vec<LanguageModelRequestTool> {
921 if model.supports_tools() {
922 self.tools()
923 .read(cx)
924 .enabled_tools(cx)
925 .into_iter()
926 .filter_map(|tool| {
927 // Skip tools that cannot be supported
928 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
929 Some(LanguageModelRequestTool {
930 name: tool.name(),
931 description: tool.description(),
932 input_schema,
933 })
934 })
935 .collect()
936 } else {
937 Vec::default()
938 }
939 }
940
    /// Appends a user message, marking referenced buffers as read in the
    /// action log and (optionally) staging a git checkpoint for the message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // The checkpoint stays pending until the turn completes.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends a hidden "continue" user message and discards any pending
    /// checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Appends an assistant message with the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1006
    /// Appends a message with a freshly allocated ID, stamps `updated_at`,
    /// and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1029
    /// Rewrites an existing message in place; returns `false` when `id` is
    /// unknown. Optionally replaces its loaded context and records a git
    /// checkpoint for the edited message.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes the message with `id`; returns `false` when it is unknown.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1072
1073 /// Returns the representation of this [`Thread`] in a textual form.
1074 ///
1075 /// This is the representation we use when attaching a thread as context to another thread.
1076 pub fn text(&self) -> String {
1077 let mut text = String::new();
1078
1079 for message in &self.messages {
1080 text.push_str(match message.role {
1081 language_model::Role::User => "User:",
1082 language_model::Role::Assistant => "Agent:",
1083 language_model::Role::System => "System:",
1084 });
1085 text.push('\n');
1086
1087 for segment in &message.segments {
1088 match segment {
1089 MessageSegment::Text(content) => text.push_str(content),
1090 MessageSegment::Thinking { text: content, .. } => {
1091 text.push_str(&format!("<think>{}</think>", content))
1092 }
1093 MessageSegment::RedactedThinking(_) => {}
1094 }
1095 }
1096 text.push('\n');
1097 }
1098
1099 text
1100 }
1101
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot first; the rest is a synchronous read.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the context text is persisted, matching what
                        // `deserialize` restores.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
            })
        })
    }
1186
    /// Returns how many more agentic turns this thread is allowed to take
    /// before `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1190
    /// Sets the number of turns this thread may still take; each call to
    /// `send_to_model` consumes one.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1194
1195 pub fn send_to_model(
1196 &mut self,
1197 model: Arc<dyn LanguageModel>,
1198 intent: CompletionIntent,
1199 window: Option<AnyWindowHandle>,
1200 cx: &mut Context<Self>,
1201 ) {
1202 if self.remaining_turns == 0 {
1203 return;
1204 }
1205
1206 self.remaining_turns -= 1;
1207
1208 let request = self.to_completion_request(model.clone(), intent, cx);
1209
1210 self.stream_completion(request, model, window, cx);
1211 }
1212
1213 pub fn used_tools_since_last_user_message(&self) -> bool {
1214 for message in self.messages.iter().rev() {
1215 if self.tool_use.message_has_tool_results(message.id) {
1216 return true;
1217 } else if message.role == Role::User {
1218 return false;
1219 }
1220 }
1221
1222 false
1223 }
1224
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// conversation history (including tool uses and their results), a notice
    /// about stale files, and the tools the model is allowed to call.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt depends on project context. If it's missing or
        // generation fails we surface an error but still build the rest of
        // the request rather than aborting.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last message eligible for prompt caching; only that
        // one is marked `cache: true` after the loop (see link below).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) comes first in the message.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results are sent back as a synthetic User message
            // immediately after the assistant message that requested them.
            // Messages with still-pending tool uses are excluded from caching.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Tell the model about files that changed out from under it.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Burn/max mode is only requested when the model supports it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1384
1385 fn to_summarize_request(
1386 &self,
1387 model: &Arc<dyn LanguageModel>,
1388 intent: CompletionIntent,
1389 added_user_message: String,
1390 cx: &App,
1391 ) -> LanguageModelRequest {
1392 let mut request = LanguageModelRequest {
1393 thread_id: None,
1394 prompt_id: None,
1395 intent: Some(intent),
1396 mode: None,
1397 messages: vec![],
1398 tools: Vec::new(),
1399 tool_choice: None,
1400 stop: Vec::new(),
1401 temperature: AgentSettings::temperature_for_model(model, cx),
1402 };
1403
1404 for message in &self.messages {
1405 let mut request_message = LanguageModelRequestMessage {
1406 role: message.role,
1407 content: Vec::new(),
1408 cache: false,
1409 };
1410
1411 for segment in &message.segments {
1412 match segment {
1413 MessageSegment::Text(text) => request_message
1414 .content
1415 .push(MessageContent::Text(text.clone())),
1416 MessageSegment::Thinking { .. } => {}
1417 MessageSegment::RedactedThinking(_) => {}
1418 }
1419 }
1420
1421 if request_message.content.is_empty() {
1422 continue;
1423 }
1424
1425 request.messages.push(request_message);
1426 }
1427
1428 request.messages.push(LanguageModelRequestMessage {
1429 role: Role::User,
1430 content: vec![MessageContent::Text(added_user_message)],
1431 cache: false,
1432 });
1433
1434 request
1435 }
1436
1437 fn attached_tracked_files_state(
1438 &self,
1439 messages: &mut Vec<LanguageModelRequestMessage>,
1440 cx: &App,
1441 ) {
1442 const STALE_FILES_HEADER: &str = include_str!("./prompts/stale_files_prompt_header.txt");
1443
1444 let mut stale_message = String::new();
1445
1446 let action_log = self.action_log.read(cx);
1447
1448 for stale_file in action_log.stale_buffers(cx) {
1449 let Some(file) = stale_file.read(cx).file() else {
1450 continue;
1451 };
1452
1453 if stale_message.is_empty() {
1454 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER.trim()).ok();
1455 }
1456
1457 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1458 }
1459
1460 let mut content = Vec::with_capacity(2);
1461
1462 if !stale_message.is_empty() {
1463 content.push(stale_message.into());
1464 }
1465
1466 if !content.is_empty() {
1467 let context_message = LanguageModelRequestMessage {
1468 role: Role::User,
1469 content,
1470 cache: false,
1471 };
1472
1473 messages.push(context_message);
1474 }
1475 }
1476
    /// Streams a completion for `request` from `model`, applying events to the
    /// thread as they arrive (text/thinking chunks, tool uses, usage updates),
    /// and handling stop reasons and errors when the stream ends. The spawned
    /// task is tracked in `pending_completions` so it can be canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        // Unique id ties this task to its entry in `pending_completions`.
        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed (used by tests/telemetry), keep
        // a copy of the request plus every response event to hand back later.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the telemetry event at the end
            // can report this request's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message created for this response, once known.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Recoverable errors (bad tool JSON) are handled inline;
                        // anything else aborts the stream.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so add
                                // only the delta since the previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // While input JSON is still streaming in, remember
                                // the partial input so the UI can render it live.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so the UI can process the event we just applied.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                // The model wants to call tools: kick them off.
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Walk backwards collecting messages through the
                                    // start of the refused user turn.
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to dedicated UI treatments;
                            // everything else gets a generic error banner.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report this request's token usage delta to telemetry.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1858
1859 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1860 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1861 println!("No thread summary model");
1862 return;
1863 };
1864
1865 if !model.provider.is_authenticated(cx) {
1866 return;
1867 }
1868
1869 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1870
1871 let request = self.to_summarize_request(
1872 &model.model,
1873 CompletionIntent::ThreadSummarization,
1874 added_user_message.into(),
1875 cx,
1876 );
1877
1878 self.summary = ThreadSummary::Generating;
1879
1880 self.pending_summary = cx.spawn(async move |this, cx| {
1881 let result = async {
1882 let mut messages = model.model.stream_completion(request, &cx).await?;
1883
1884 let mut new_summary = String::new();
1885 while let Some(event) = messages.next().await {
1886 let Ok(event) = event else {
1887 continue;
1888 };
1889 let text = match event {
1890 LanguageModelCompletionEvent::Text(text) => text,
1891 LanguageModelCompletionEvent::StatusUpdate(
1892 CompletionRequestStatus::UsageUpdated { amount, limit },
1893 ) => {
1894 this.update(cx, |thread, _cx| {
1895 thread.last_usage = Some(RequestUsage {
1896 limit,
1897 amount: amount as i32,
1898 });
1899 })?;
1900 continue;
1901 }
1902 _ => continue,
1903 };
1904
1905 let mut lines = text.lines();
1906 new_summary.extend(lines.next());
1907
1908 // Stop if the LLM generated multiple lines.
1909 if lines.next().is_some() {
1910 break;
1911 }
1912 }
1913
1914 anyhow::Ok(new_summary)
1915 }
1916 .await;
1917
1918 this.update(cx, |this, cx| {
1919 match result {
1920 Ok(new_summary) => {
1921 if new_summary.is_empty() {
1922 this.summary = ThreadSummary::Error;
1923 } else {
1924 this.summary = ThreadSummary::Ready(new_summary.into());
1925 }
1926 }
1927 Err(err) => {
1928 this.summary = ThreadSummary::Error;
1929 log::error!("Failed to generate thread summary: {}", err);
1930 }
1931 }
1932 cx.emit(ThreadEvent::SummaryGenerated);
1933 })
1934 .log_err()?;
1935
1936 Some(())
1937 });
1938 }
1939
    /// Starts generating a detailed (multi-paragraph) summary of the thread
    /// unless one is already generated — or generating — for the current last
    /// message. The result is broadcast via the detailed-summary watch channel
    /// and the thread is re-saved so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Broadcast the state change before spawning, so readers see
        // "Generating" immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed chunks, skipping over chunk-level errors.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2029
    /// Waits until the detailed-summary state settles, returning the generated
    /// summary, or the plain thread text when no summary will be generated.
    /// Returns `None` if the thread entity or the channel goes away.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in flight; wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    // Fall back to the raw textual representation of the thread.
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2047
2048 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2049 self.detailed_summary_rx
2050 .borrow()
2051 .text()
2052 .unwrap_or_else(|| self.text().into())
2053 }
2054
    /// Whether a detailed summary is currently being generated in the background.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2061
    /// Dispatches every idle pending tool use: tools that need confirmation
    /// (and aren't globally allowed) wait for the user, known tools run
    /// immediately, and unknown (hallucinated) tool names get an error result.
    /// Returns the tool uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // One shared request snapshot is handed to every tool invocation.
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // The "always allow" setting bypasses per-tool confirmation.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AgentSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have; report it back.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2116
2117 pub fn handle_hallucinated_tool_use(
2118 &mut self,
2119 tool_use_id: LanguageModelToolUseId,
2120 hallucinated_tool_name: Arc<str>,
2121 window: Option<AnyWindowHandle>,
2122 cx: &mut Context<Thread>,
2123 ) {
2124 let available_tools = self.tools.read(cx).enabled_tools(cx);
2125
2126 let tool_list = available_tools
2127 .iter()
2128 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2129 .collect::<Vec<_>>()
2130 .join("\n");
2131
2132 let error_message = format!(
2133 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2134 hallucinated_tool_name, tool_list
2135 );
2136
2137 let pending_tool_use = self.tool_use.insert_tool_output(
2138 tool_use_id.clone(),
2139 hallucinated_tool_name,
2140 Err(anyhow!("Missing tool call: {error_message}")),
2141 self.configured_model.as_ref(),
2142 );
2143
2144 cx.emit(ThreadEvent::MissingToolUse {
2145 tool_use_id: tool_use_id.clone(),
2146 ui_text: error_message.into(),
2147 });
2148
2149 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2150 }
2151
2152 pub fn receive_invalid_tool_json(
2153 &mut self,
2154 tool_use_id: LanguageModelToolUseId,
2155 tool_name: Arc<str>,
2156 invalid_json: Arc<str>,
2157 error: String,
2158 window: Option<AnyWindowHandle>,
2159 cx: &mut Context<Thread>,
2160 ) {
2161 log::error!("The model returned invalid input JSON: {invalid_json}");
2162
2163 let pending_tool_use = self.tool_use.insert_tool_output(
2164 tool_use_id.clone(),
2165 tool_name,
2166 Err(anyhow!("Error parsing input JSON: {error}")),
2167 self.configured_model.as_ref(),
2168 );
2169 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2170 pending_tool_use.ui_text.clone()
2171 } else {
2172 log::error!(
2173 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2174 );
2175 format!("Unknown tool {}", tool_use_id).into()
2176 };
2177
2178 cx.emit(ThreadEvent::InvalidToolInput {
2179 tool_use_id: tool_use_id.clone(),
2180 ui_text,
2181 invalid_input_json: invalid_json,
2182 });
2183
2184 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2185 }
2186
2187 pub fn run_tool(
2188 &mut self,
2189 tool_use_id: LanguageModelToolUseId,
2190 ui_text: impl Into<SharedString>,
2191 input: serde_json::Value,
2192 request: Arc<LanguageModelRequest>,
2193 tool: Arc<dyn Tool>,
2194 model: Arc<dyn LanguageModel>,
2195 window: Option<AnyWindowHandle>,
2196 cx: &mut Context<Thread>,
2197 ) {
2198 let task =
2199 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2200 self.tool_use
2201 .run_pending_tool(tool_use_id, ui_text.into(), task);
2202 }
2203
    /// Runs `tool` (unless it is disabled) and returns a task that records the
    /// tool's output on this thread once it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool still yields a (failed) result, so the model always
        // receives a tool_result for every tool_use it issued.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; ignore
                // the result in that case.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2254
2255 fn tool_finished(
2256 &mut self,
2257 tool_use_id: LanguageModelToolUseId,
2258 pending_tool_use: Option<PendingToolUse>,
2259 canceled: bool,
2260 window: Option<AnyWindowHandle>,
2261 cx: &mut Context<Self>,
2262 ) {
2263 if self.all_tools_finished() {
2264 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2265 if !canceled {
2266 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2267 }
2268 self.auto_capture_telemetry(cx);
2269 }
2270 }
2271
2272 cx.emit(ThreadEvent::ToolFinished {
2273 tool_use_id,
2274 pending_tool_use,
2275 });
2276 }
2277
2278 /// Cancels the last pending completion, if there are any pending.
2279 ///
2280 /// Returns whether a completion was canceled.
2281 pub fn cancel_last_completion(
2282 &mut self,
2283 window: Option<AnyWindowHandle>,
2284 cx: &mut Context<Self>,
2285 ) -> bool {
2286 let mut canceled = self.pending_completions.pop().is_some();
2287
2288 for pending_tool_use in self.tool_use.cancel_pending() {
2289 canceled = true;
2290 self.tool_finished(
2291 pending_tool_use.id.clone(),
2292 Some(pending_tool_use),
2293 true,
2294 window,
2295 cx,
2296 );
2297 }
2298
2299 if canceled {
2300 cx.emit(ThreadEvent::CompletionCanceled);
2301
2302 // When canceled, we always want to insert the checkpoint.
2303 // (We skip over finalize_pending_checkpoint, because it
2304 // would conclude we didn't have anything to insert here.)
2305 if let Some(checkpoint) = self.pending_checkpoint.take() {
2306 self.insert_checkpoint(checkpoint, cx);
2307 }
2308 } else {
2309 self.finalize_pending_checkpoint(cx);
2310 }
2311
2312 canceled
2313 }
2314
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Pure notification: no thread state changes here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2322
    /// Thread-level feedback, if any. This is only set when feedback was
    /// reported while no assistant message existed to attach it to (see
    /// `report_feedback`).
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2326
    /// The feedback recorded for `message_id`, if the user has rated that
    /// message.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2330
2331 pub fn report_message_feedback(
2332 &mut self,
2333 message_id: MessageId,
2334 feedback: ThreadFeedback,
2335 cx: &mut Context<Self>,
2336 ) -> Task<Result<()>> {
2337 if self.message_feedback.get(&message_id) == Some(&feedback) {
2338 return Task::ready(Ok(()));
2339 }
2340
2341 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2342 let serialized_thread = self.serialize(cx);
2343 let thread_id = self.id().clone();
2344 let client = self.project.read(cx).client();
2345
2346 let enabled_tool_names: Vec<String> = self
2347 .tools()
2348 .read(cx)
2349 .enabled_tools(cx)
2350 .iter()
2351 .map(|tool| tool.name())
2352 .collect();
2353
2354 self.message_feedback.insert(message_id, feedback);
2355
2356 cx.notify();
2357
2358 let message_content = self
2359 .message(message_id)
2360 .map(|msg| msg.to_string())
2361 .unwrap_or_default();
2362
2363 cx.background_spawn(async move {
2364 let final_project_snapshot = final_project_snapshot.await;
2365 let serialized_thread = serialized_thread.await?;
2366 let thread_data =
2367 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2368
2369 let rating = match feedback {
2370 ThreadFeedback::Positive => "positive",
2371 ThreadFeedback::Negative => "negative",
2372 };
2373 telemetry::event!(
2374 "Assistant Thread Rated",
2375 rating,
2376 thread_id,
2377 enabled_tool_names,
2378 message_id = message_id.0,
2379 message_content,
2380 thread_data,
2381 final_project_snapshot
2382 );
2383 client.telemetry().flush_events().await;
2384
2385 Ok(())
2386 })
2387 }
2388
2389 pub fn report_feedback(
2390 &mut self,
2391 feedback: ThreadFeedback,
2392 cx: &mut Context<Self>,
2393 ) -> Task<Result<()>> {
2394 let last_assistant_message_id = self
2395 .messages
2396 .iter()
2397 .rev()
2398 .find(|msg| msg.role == Role::Assistant)
2399 .map(|msg| msg.id);
2400
2401 if let Some(message_id) = last_assistant_message_id {
2402 self.report_message_feedback(message_id, feedback, cx)
2403 } else {
2404 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2405 let serialized_thread = self.serialize(cx);
2406 let thread_id = self.id().clone();
2407 let client = self.project.read(cx).client();
2408 self.feedback = Some(feedback);
2409 cx.notify();
2410
2411 cx.background_spawn(async move {
2412 let final_project_snapshot = final_project_snapshot.await;
2413 let serialized_thread = serialized_thread.await?;
2414 let thread_data = serde_json::to_value(serialized_thread)
2415 .unwrap_or_else(|_| serde_json::Value::Null);
2416
2417 let rating = match feedback {
2418 ThreadFeedback::Positive => "positive",
2419 ThreadFeedback::Negative => "negative",
2420 };
2421 telemetry::event!(
2422 "Assistant Thread Rated",
2423 rating,
2424 thread_id,
2425 thread_data,
2426 final_project_snapshot
2427 );
2428 client.telemetry().flush_events().await;
2429
2430 Ok(())
2431 })
2432 }
2433 }
2434
2435 /// Create a snapshot of the current project state including git information and unsaved buffers.
2436 fn project_snapshot(
2437 project: Entity<Project>,
2438 cx: &mut Context<Self>,
2439 ) -> Task<Arc<ProjectSnapshot>> {
2440 let git_store = project.read(cx).git_store().clone();
2441 let worktree_snapshots: Vec<_> = project
2442 .read(cx)
2443 .visible_worktrees(cx)
2444 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2445 .collect();
2446
2447 cx.spawn(async move |_, cx| {
2448 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2449
2450 let mut unsaved_buffers = Vec::new();
2451 cx.update(|app_cx| {
2452 let buffer_store = project.read(app_cx).buffer_store();
2453 for buffer_handle in buffer_store.read(app_cx).buffers() {
2454 let buffer = buffer_handle.read(app_cx);
2455 if buffer.is_dirty() {
2456 if let Some(file) = buffer.file() {
2457 let path = file.path().to_string_lossy().to_string();
2458 unsaved_buffers.push(path);
2459 }
2460 }
2461 }
2462 })
2463 .ok();
2464
2465 Arc::new(ProjectSnapshot {
2466 worktree_snapshots,
2467 unsaved_buffer_paths: unsaved_buffers,
2468 timestamp: Utc::now(),
2469 })
2470 })
2471 }
2472
    /// Captures one worktree's path plus its git status (branch, HEAD sha,
    /// remote URL, and HEAD-to-worktree diff) as a background task.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot instead of
            // propagating the failure.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree (if any) and ask
            // it for its git state via a repository job.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task<GitState>>>: any failure
            // along the way simply yields no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2550
2551 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2552 let mut markdown = Vec::new();
2553
2554 let summary = self.summary().or_default();
2555 writeln!(markdown, "# {summary}\n")?;
2556
2557 for message in self.messages() {
2558 writeln!(
2559 markdown,
2560 "## {role}\n",
2561 role = match message.role {
2562 Role::User => "User",
2563 Role::Assistant => "Agent",
2564 Role::System => "System",
2565 }
2566 )?;
2567
2568 if !message.loaded_context.text.is_empty() {
2569 writeln!(markdown, "{}", message.loaded_context.text)?;
2570 }
2571
2572 if !message.loaded_context.images.is_empty() {
2573 writeln!(
2574 markdown,
2575 "\n{} images attached as context.\n",
2576 message.loaded_context.images.len()
2577 )?;
2578 }
2579
2580 for segment in &message.segments {
2581 match segment {
2582 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2583 MessageSegment::Thinking { text, .. } => {
2584 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2585 }
2586 MessageSegment::RedactedThinking(_) => {}
2587 }
2588 }
2589
2590 for tool_use in self.tool_uses_for_message(message.id, cx) {
2591 writeln!(
2592 markdown,
2593 "**Use Tool: {} ({})**",
2594 tool_use.name, tool_use.id
2595 )?;
2596 writeln!(markdown, "```json")?;
2597 writeln!(
2598 markdown,
2599 "{}",
2600 serde_json::to_string_pretty(&tool_use.input)?
2601 )?;
2602 writeln!(markdown, "```")?;
2603 }
2604
2605 for tool_result in self.tool_results_for_message(message.id) {
2606 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2607 if tool_result.is_error {
2608 write!(markdown, " (Error)")?;
2609 }
2610
2611 writeln!(markdown, "**\n")?;
2612 match &tool_result.content {
2613 LanguageModelToolResultContent::Text(text) => {
2614 writeln!(markdown, "{text}")?;
2615 }
2616 LanguageModelToolResultContent::Image(image) => {
2617 writeln!(markdown, "", image.source)?;
2618 }
2619 }
2620
2621 if let Some(output) = tool_result.output.as_ref() {
2622 writeln!(
2623 markdown,
2624 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2625 serde_json::to_string_pretty(output)?
2626 )?;
2627 }
2628 }
2629 }
2630
2631 Ok(String::from_utf8_lossy(&markdown).to_string())
2632 }
2633
2634 pub fn keep_edits_in_range(
2635 &mut self,
2636 buffer: Entity<language::Buffer>,
2637 buffer_range: Range<language::Anchor>,
2638 cx: &mut Context<Self>,
2639 ) {
2640 self.action_log.update(cx, |action_log, cx| {
2641 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2642 });
2643 }
2644
2645 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2646 self.action_log
2647 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2648 }
2649
2650 pub fn reject_edits_in_ranges(
2651 &mut self,
2652 buffer: Entity<language::Buffer>,
2653 buffer_ranges: Vec<Range<language::Anchor>>,
2654 cx: &mut Context<Self>,
2655 ) -> Task<Result<()>> {
2656 self.action_log.update(cx, |action_log, cx| {
2657 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2658 })
2659 }
2660
    /// The log of buffer edits performed by the agent in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2664
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2668
2669 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2670 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2671 return;
2672 }
2673
2674 let now = Instant::now();
2675 if let Some(last) = self.last_auto_capture_at {
2676 if now.duration_since(last).as_secs() < 10 {
2677 return;
2678 }
2679 }
2680
2681 self.last_auto_capture_at = Some(now);
2682
2683 let thread_id = self.id().clone();
2684 let github_login = self
2685 .project
2686 .read(cx)
2687 .user_store()
2688 .read(cx)
2689 .current_user()
2690 .map(|user| user.github_login.clone());
2691 let client = self.project.read(cx).client();
2692 let serialize_task = self.serialize(cx);
2693
2694 cx.background_executor()
2695 .spawn(async move {
2696 if let Ok(serialized_thread) = serialize_task.await {
2697 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2698 telemetry::event!(
2699 "Agent Thread Auto-Captured",
2700 thread_id = thread_id.to_string(),
2701 thread_data = thread_data,
2702 auto_capture_reason = "tracked_user",
2703 github_login = github_login
2704 );
2705
2706 client.telemetry().flush_events().await;
2707 }
2708 }
2709 })
2710 .detach();
2711 }
2712
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2716
2717 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2718 let Some(model) = self.configured_model.as_ref() else {
2719 return TotalTokenUsage::default();
2720 };
2721
2722 let max = model.model.max_token_count();
2723
2724 let index = self
2725 .messages
2726 .iter()
2727 .position(|msg| msg.id == message_id)
2728 .unwrap_or(0);
2729
2730 if index == 0 {
2731 return TotalTokenUsage { total: 0, max };
2732 }
2733
2734 let token_usage = &self
2735 .request_token_usage
2736 .get(index - 1)
2737 .cloned()
2738 .unwrap_or_default();
2739
2740 TotalTokenUsage {
2741 total: token_usage.total_tokens() as usize,
2742 max,
2743 }
2744 }
2745
2746 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2747 let model = self.configured_model.as_ref()?;
2748
2749 let max = model.model.max_token_count();
2750
2751 if let Some(exceeded_error) = &self.exceeded_window_error {
2752 if model.model.id() == exceeded_error.model_id {
2753 return Some(TotalTokenUsage {
2754 total: exceeded_error.token_count,
2755 max,
2756 });
2757 }
2758 }
2759
2760 let total = self
2761 .token_usage_at_last_message()
2762 .unwrap_or_default()
2763 .total_tokens() as usize;
2764
2765 Some(TotalTokenUsage { total, max })
2766 }
2767
2768 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2769 self.request_token_usage
2770 .get(self.messages.len().saturating_sub(1))
2771 .or_else(|| self.request_token_usage.last())
2772 .cloned()
2773 }
2774
2775 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2776 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2777 self.request_token_usage
2778 .resize(self.messages.len(), placeholder);
2779
2780 if let Some(last) = self.request_token_usage.last_mut() {
2781 *last = token_usage;
2782 }
2783 }
2784
2785 pub fn deny_tool_use(
2786 &mut self,
2787 tool_use_id: LanguageModelToolUseId,
2788 tool_name: Arc<str>,
2789 window: Option<AnyWindowHandle>,
2790 cx: &mut Context<Self>,
2791 ) {
2792 let err = Err(anyhow::anyhow!(
2793 "Permission to run tool action denied by user"
2794 ));
2795
2796 self.tool_use.insert_tool_output(
2797 tool_use_id.clone(),
2798 tool_name,
2799 err,
2800 self.configured_model.as_ref(),
2801 );
2802 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2803 }
2804}
2805
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before more completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has hit its model request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, rendered with a header and a message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2818
/// Events emitted by a `Thread` as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be presented to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant thinking streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool call completed (see `Thread::tool_finished`).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Listeners should cancel in-progress editing (see `cancel_editing`).
    CancelEditing,
    /// A pending completion or tool run was canceled.
    CompletionCanceled,
}
2862
// Allows `cx.emit(ThreadEvent::…)` on a Thread entity.
impl EventEmitter<ThreadEvent> for Thread {}
2864
/// A completion request that is currently in flight; popped (and thereby
/// canceled) by `Thread::cancel_last_completion`.
struct PendingCompletion {
    // Identifier for this completion. NOTE(review): presumably assigned via
    // `post_inc` on a counter — confirm against the creation site.
    id: usize,
    queue_state: QueueState,
    // Keeps the spawned completion work alive for the lifetime of this entry.
    _task: Task<()>,
}
2870
2871#[cfg(test)]
2872mod tests {
2873 use super::*;
2874 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2875 use agent_settings::{AgentSettings, LanguageModelParameters};
2876 use assistant_tool::ToolRegistry;
2877 use editor::EditorSettings;
2878 use gpui::TestAppContext;
2879 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2880 use project::{FakeFs, Project};
2881 use prompt_store::PromptBuilder;
2882 use serde_json::json;
2883 use settings::{Settings, SettingsStore};
2884 use std::sync::Arc;
2885 use theme::ThemeSettings;
2886 use util::path;
2887 use workspace::Workspace;
2888
    // Verifies that a file attached as context is rendered verbatim into both
    // the stored message and the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two messages: messages[0] is presumably the system prompt — the
        // user message (context + prompt) is at index 1.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2965
    // Exercises `new_context_for_thread`: each message should only carry
    // context that wasn't already attached to an earlier message, and the
    // optional message-id argument re-evaluates what counts as "new" from
    // that message onward.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Passing `Some(message2_id)` re-evaluates "new" from message 2
        // onward: contexts 2 and 3 plus the newly added file4 count as new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3121
    // Verifies that messages inserted without any context have empty context
    // text and appear unmodified in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3199
    // Verifies that when a buffer attached as context is modified after being
    // read, the next completion request ends with a "files changed" notice.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3287
    // Verifies how `model_parameters` temperature overrides are matched:
    // provider+model, model-only, and provider-only settings all apply, but a
    // mismatched provider for the same model name must not.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3381
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        // Exercises the summary state machine (Pending -> Generating ->
        // Ready) and checks in which states `set_summary` is honored.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        // NOTE(review): this passes ThreadSummarization as the intent of a
        // plain user send, while test_thread_summary_error_retry uses
        // UserPrompt for the same shape of call — confirm intentional.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks through the fake model's
        // pending (summary) completion, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks concatenated)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        // State stays Ready, but the displayed text falls back to DEFAULT.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3466
3467 #[gpui::test]
3468 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3469 init_test_settings(cx);
3470
3471 let project = create_test_project(cx, json!({})).await;
3472
3473 let (_, _thread_store, thread, _context_store, model) =
3474 setup_test_environment(cx, project.clone()).await;
3475
3476 test_summarize_error(&model, &thread, cx);
3477
3478 // Now we should be able to set a summary
3479 thread.update(cx, |thread, cx| {
3480 thread.set_summary("Brief Intro", cx);
3481 });
3482
3483 thread.read_with(cx, |thread, _| {
3484 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3485 assert_eq!(thread.summary().or_default(), "Brief Intro");
3486 });
3487 }
3488
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        // After a failed summarization, further sends must not auto-retry,
        // but an explicit `summarize` call must.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread's summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a summary from the fake model and close the stream cleanly.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // The manual retry succeeded: Error -> Generating -> Ready.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3539
    /// Shared setup for the summary-error tests: sends one user message,
    /// lets the main completion succeed, then ends the summary completion
    /// stream without producing any text, leaving the thread's summary in
    /// [`ThreadSummary::Error`].
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // With the assistant reply in place, summarization has started.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending: the stream closes without ever
        // yielding text, which the thread records as a failed summarization.
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3574
    /// Completes the fake model's pending completion with a canned assistant
    /// reply: parks the executor so the request is in flight, streams one
    /// chunk, closes the stream, then parks again so downstream work settles.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3581
    /// Registers the global settings and subsystems every test in this
    /// module relies on. The `SettingsStore` is installed first; the
    /// init/register calls that follow presumably read it — keep it first.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3598
3599 // Helper to create a test project with test files
3600 async fn create_test_project(
3601 cx: &mut TestAppContext,
3602 files: serde_json::Value,
3603 ) -> Entity<Project> {
3604 let fs = FakeFs::new(cx.executor());
3605 fs.insert_tree(path!("/test"), files).await;
3606 Project::test(fs, [path!("/test").as_ref()], cx).await
3607 }
3608
    /// Builds the shared test fixture: a windowed `Workspace`, a
    /// `ThreadStore` with an empty tool set, a fresh `Thread`, a
    /// `ContextStore`, and a fake language model registered as both the
    /// default model and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        // Note: `cx` is shadowed here with the window-aware context returned
        // by `add_window_view`; everything below runs against that context.
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // One fake provider/model pair serves both registry slots below.
        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                // `provider` is moved here — this must stay the last use.
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3663
3664 async fn add_file_to_context(
3665 project: &Entity<Project>,
3666 context_store: &Entity<ContextStore>,
3667 path: &str,
3668 cx: &mut TestAppContext,
3669 ) -> Result<Entity<language::Buffer>> {
3670 let buffer_path = project
3671 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3672 .unwrap();
3673
3674 let buffer = project
3675 .update(cx, |project, cx| {
3676 project.open_buffer(buffer_path.clone(), cx)
3677 })
3678 .await
3679 .unwrap();
3680
3681 context_store.update(cx, |context_store, cx| {
3682 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3683 });
3684
3685 Ok(buffer)
3686 }
3687}