1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Range of the crease within the message text (offsets are persisted as
    // `start`/`end` in `SerializedCrease`).
    pub range: Range<usize>,
    // Icon path and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered pieces of content: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context attached to this message; only the flattened text survives
    // serialization (see `Thread::deserialize`).
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    // Hidden messages (e.g. the auto-inserted "continue" prompt) are not
    // treated as user-visible turns.
    pub is_hidden: bool,
}
120
121impl Message {
122 /// Returns whether the message contains any meaningful text that should be displayed
123 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
124 pub fn should_display_content(&self) -> bool {
125 self.segments.iter().all(|segment| segment.should_display())
126 }
127
128 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
129 if let Some(MessageSegment::Thinking {
130 text: segment,
131 signature: current_signature,
132 }) = self.segments.last_mut()
133 {
134 if let Some(signature) = signature {
135 *current_signature = Some(signature);
136 }
137 segment.push_str(text);
138 } else {
139 self.segments.push(MessageSegment::Thinking {
140 text: text.to_string(),
141 signature,
142 });
143 }
144 }
145
146 pub fn push_text(&mut self, text: &str) {
147 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Text(text.to_string()));
151 }
152 }
153
154 pub fn to_string(&self) -> String {
155 let mut result = String::new();
156
157 if !self.loaded_context.text.is_empty() {
158 result.push_str(&self.loaded_context.text);
159 }
160
161 for segment in &self.segments {
162 match segment {
163 MessageSegment::Text(text) => result.push_str(text),
164 MessageSegment::Thinking { text, .. } => {
165 result.push_str("<think>\n");
166 result.push_str(text);
167 result.push_str("\n</think>");
168 }
169 MessageSegment::RedactedThinking(_) => {}
170 }
171 }
172
173 result
174 }
175}
176
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content.
    ///
    /// A segment is displayable when its text is non-empty; redacted thinking
    /// is never shown. The previous implementation had the emptiness check
    /// inverted (`text.is_empty()`), making empty segments report as
    /// displayable and non-empty ones as hidden.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
196
/// A snapshot of the project's worktrees, captured when a thread starts
/// (see `Thread::project_snapshot`) and persisted with the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when the worktree has no associated git repository state.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
217
/// Associates a git checkpoint with the message whose processing it precedes,
/// so the project can be rolled back to the state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message (thumbs up/down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
229
230pub enum LastRestoreCheckpoint {
231 Pending {
232 message_id: MessageId,
233 },
234 Error {
235 message_id: MessageId,
236 error: String,
237 },
238}
239
240impl LastRestoreCheckpoint {
241 pub fn message_id(&self) -> MessageId {
242 match self {
243 LastRestoreCheckpoint::Pending { message_id } => *message_id,
244 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
245 }
246 }
247}
248
/// Lifecycle of the thread's long-form summary: not yet requested, being
/// generated, or generated. `message_id` records the last message included.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
261
262impl DetailedSummaryState {
263 fn text(&self) -> Option<SharedString> {
264 if let Self::Generated { text, .. } = self {
265 Some(text.clone())
266 } else {
267 None
268 }
269 }
270}
271
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    // Tokens consumed so far.
    pub total: usize,
    // Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the model's context window.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A missing or
    /// unparsable value now falls back to the default instead of panicking
    /// (the previous `.parse().unwrap()` aborted on malformed input).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the running total.
    /// Saturates instead of overflowing on absurdly large totals.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// How close the thread is to exhausting the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
316
/// Progress of a pending completion through the upstream request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    // Request is being sent; no queue position known yet.
    Sending,
    // Waiting upstream at the given queue position.
    Queued { position: usize },
    // The completion has begun streaming.
    Started,
}
323
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Bumped on every mutation via `touch_updated_at`.
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting `DetailedSummaryState` changes; `rx` is
    // cloned into `serialize` to persist the latest state.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Ordered by ascending `MessageId`; `message()` binary-searches on this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    // Non-empty while a completion is streaming; see `is_generating`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured with the latest user message; kept or discarded by
    // `finalize_pending_checkpoint` once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing every request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Turns left before `send_to_model` stops; initialized to `u32::MAX`.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
364
365#[derive(Clone, Debug, PartialEq, Eq)]
366pub enum ThreadSummary {
367 Pending,
368 Generating,
369 Ready(SharedString),
370 Error,
371}
372
373impl ThreadSummary {
374 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
375
376 pub fn or_default(&self) -> SharedString {
377 self.unwrap_or(Self::DEFAULT)
378 }
379
380 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
381 self.ready().unwrap_or_else(|| message.into())
382 }
383
384 pub fn ready(&self) -> Option<SharedString> {
385 match self {
386 ThreadSummary::Ready(summary) => Some(summary.clone()),
387 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
388 }
389 }
390}
391
/// Recorded when a completion request exceeded the model's context window,
/// so the UI can explain why the request failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
399
400impl Thread {
    /// Creates an empty thread backed by `project`, using the global default
    /// model and the settings-configured completion mode.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        // Start with the global default model; can be swapped later via
        // `set_configured_model`.
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's initial state in the background so it
                // can be included when the thread is serialized.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
454
    /// Reconstructs a [`Thread`] from its serialized form.
    ///
    /// Ephemeral state (pending completions, checkpoints, feedback, callbacks)
    /// is reset; only durable data is restored from `serialized`.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue message numbering after the highest persisted id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded with the thread; fall back to the global
        // default when it's absent or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Context entities and images are not persisted; only the
                    // flattened context text survives a round-trip.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
575
    /// Installs a hook invoked with each completion request and the events it streamed.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's stable unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last mutated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as just-modified.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the global default
    /// when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for completions and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
623
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Updates the summary text, emitting `SummaryChanged` when it changes.
    ///
    /// No-op while a summary is pending or generating; an empty string is
    /// replaced with the default placeholder.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
654
    /// Looks up a message by id.
    ///
    /// Relies on `messages` being sorted by ascending id (ids are assigned
    /// monotonically in `insert_message` and removals preserve order).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale, i.e. no
    /// chunk has arrived within the threshold. Returns `None` when no chunk
    /// has been received yet.
    ///
    /// NOTE(review): earlier docs claimed this returns `None` whenever
    /// `is_generating()` is false, but the implementation only consults
    /// `last_received_chunk_at` — confirm callers gate on `is_generating()`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250; // milliseconds

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    // Records that a streamed chunk just arrived; see `is_generation_stale`.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue progress of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
712
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpointed message.
    ///
    /// Progress/failure is surfaced via `last_restore_checkpoint` and
    /// `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can show what failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    /// Finalizes the checkpoint captured with the latest user message.
    /// No-op while generation is still in progress or when nothing is pending.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }
764
    /// Compares `pending_checkpoint` against the repository's current state
    /// and keeps it only when the turn actually changed something (or when the
    /// comparison itself failed, to err on the side of keeping it).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed".
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't snapshot the final state; keep the checkpoint anyway.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    // Registers a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
810
    /// Removes the message with `message_id` and everything after it, dropping
    /// the checkpoints attached to the removed messages. No-op if the id is
    /// not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates over the context entries attached to the given message
    /// (empty if the message is not found or has no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
832
833 pub fn is_turn_end(&self, ix: usize) -> bool {
834 if self.messages.is_empty() {
835 return false;
836 }
837
838 if !self.is_generating() && ix == self.messages.len() - 1 {
839 return true;
840 }
841
842 let Some(message) = self.messages.get(ix) else {
843 return false;
844 };
845
846 if message.role != Role::Assistant {
847 return false;
848 }
849
850 self.messages
851 .get(ix + 1)
852 .and_then(|message| {
853 self.message(message.id)
854 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
855 })
856 .unwrap_or(false)
857 }
858
    /// Usage reported by the most recent completion, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Textual output of a finished tool use; `None` for image results or
    /// when no result exists for `id`.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card associated with a finished tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
908
909 /// Return tools that are both enabled and supported by the model
910 pub fn available_tools(
911 &self,
912 cx: &App,
913 model: Arc<dyn LanguageModel>,
914 ) -> Vec<LanguageModelRequestTool> {
915 if model.supports_tools() {
916 self.tools()
917 .read(cx)
918 .enabled_tools(cx)
919 .into_iter()
920 .filter_map(|tool| {
921 // Skip tools that cannot be supported
922 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
923 Some(LanguageModelRequestTool {
924 name: tool.name(),
925 description: tool.description(),
926 input_schema,
927 })
928 })
929 .collect()
930 } else {
931 Vec::default()
932 }
933 }
934
    /// Inserts a user message, marking any buffers referenced by its context
    /// as read in the action log and, when `git_checkpoint` is provided,
    /// recording it as the pending checkpoint for this turn.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Inserts a hidden user message that nudges the model to continue, e.g.
    /// after a tool-use limit pause. Clears any pending checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
985
    /// Appends an assistant message with no context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    /// Appends a message with the next monotonically-increasing id, bumps
    /// `updated_at`, and emits `MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1023
    /// Replaces an existing message's role, segments, and creases in place.
    ///
    /// Context is replaced only when `loaded_context` is `Some`; a provided
    /// `checkpoint` is attached to the message. Returns `false` when no
    /// message has the given id. Emits `MessageEdited` on success.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes a single message by id, returning `false` when it does not
    /// exist. Emits `MessageDeleted` on success.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1066
1067 /// Returns the representation of this [`Thread`] in a textual form.
1068 ///
1069 /// This is the representation we use when attaching a thread as context to another thread.
1070 pub fn text(&self) -> String {
1071 let mut text = String::new();
1072
1073 for message in &self.messages {
1074 text.push_str(match message.role {
1075 language_model::Role::User => "User:",
1076 language_model::Role::Assistant => "Agent:",
1077 language_model::Role::System => "System:",
1078 });
1079 text.push('\n');
1080
1081 for segment in &message.segments {
1082 match segment {
1083 MessageSegment::Text(content) => text.push_str(content),
1084 MessageSegment::Thinking { text: content, .. } => {
1085 text.push_str(&format!("<think>{}</think>", content))
1086 }
1087 MessageSegment::RedactedThinking(_) => {}
1088 }
1089 }
1090 text.push('\n');
1091 }
1092
1093 text
1094 }
1095
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Await the background snapshot first, then read the thread state.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted; see
                        // `deserialize`, which restores it without entities.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
            })
        })
    }
1180
    /// Returns how many model turns this thread may still take before
    /// [`Self::send_to_model`] starts ignoring requests.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1184
    /// Sets the turn budget consumed by [`Self::send_to_model`].
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1188
1189 pub fn send_to_model(
1190 &mut self,
1191 model: Arc<dyn LanguageModel>,
1192 intent: CompletionIntent,
1193 window: Option<AnyWindowHandle>,
1194 cx: &mut Context<Self>,
1195 ) {
1196 if self.remaining_turns == 0 {
1197 return;
1198 }
1199
1200 self.remaining_turns -= 1;
1201
1202 let request = self.to_completion_request(model.clone(), intent, cx);
1203
1204 self.stream_completion(request, model, window, cx);
1205 }
1206
1207 pub fn used_tools_since_last_user_message(&self) -> bool {
1208 for message in self.messages.iter().rev() {
1209 if self.tool_use.message_has_tool_results(message.id) {
1210 return true;
1211 } else if message.role == Role::User {
1212 return false;
1213 }
1214 }
1215
1216 false
1217 }
1218
    /// Assembles the full [`LanguageModelRequest`] for the next completion:
    /// system prompt, conversation history with tool uses paired to their
    /// results, a prompt-cache anchor, stale-file notices, and available tools.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // The system prompt advertises the available tools, so gather them first.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across turns, so mark it cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // NOTE(review): the project context should be populated before the
            // first completion; hitting this branch indicates an upstream bug.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to use as a prompt
        // cache anchor (i.e. has no still-pending tool uses before it).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message text.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results are sent as a separate user-role message immediately
            // following the assistant message that issued the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use means this prefix is still changing,
                    // so it must not become the cache anchor.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Tell the model which files changed since it last read them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1378
1379 fn to_summarize_request(
1380 &self,
1381 model: &Arc<dyn LanguageModel>,
1382 intent: CompletionIntent,
1383 added_user_message: String,
1384 cx: &App,
1385 ) -> LanguageModelRequest {
1386 let mut request = LanguageModelRequest {
1387 thread_id: None,
1388 prompt_id: None,
1389 intent: Some(intent),
1390 mode: None,
1391 messages: vec![],
1392 tools: Vec::new(),
1393 tool_choice: None,
1394 stop: Vec::new(),
1395 temperature: AgentSettings::temperature_for_model(model, cx),
1396 };
1397
1398 for message in &self.messages {
1399 let mut request_message = LanguageModelRequestMessage {
1400 role: message.role,
1401 content: Vec::new(),
1402 cache: false,
1403 };
1404
1405 for segment in &message.segments {
1406 match segment {
1407 MessageSegment::Text(text) => request_message
1408 .content
1409 .push(MessageContent::Text(text.clone())),
1410 MessageSegment::Thinking { .. } => {}
1411 MessageSegment::RedactedThinking(_) => {}
1412 }
1413 }
1414
1415 if request_message.content.is_empty() {
1416 continue;
1417 }
1418
1419 request.messages.push(request_message);
1420 }
1421
1422 request.messages.push(LanguageModelRequestMessage {
1423 role: Role::User,
1424 content: vec![MessageContent::Text(added_user_message)],
1425 cache: false,
1426 });
1427
1428 request
1429 }
1430
1431 fn attached_tracked_files_state(
1432 &self,
1433 messages: &mut Vec<LanguageModelRequestMessage>,
1434 cx: &App,
1435 ) {
1436 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1437
1438 let mut stale_message = String::new();
1439
1440 let action_log = self.action_log.read(cx);
1441
1442 for stale_file in action_log.stale_buffers(cx) {
1443 let Some(file) = stale_file.read(cx).file() else {
1444 continue;
1445 };
1446
1447 if stale_message.is_empty() {
1448 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1449 }
1450
1451 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1452 }
1453
1454 let mut content = Vec::with_capacity(2);
1455
1456 if !stale_message.is_empty() {
1457 content.push(stale_message.into());
1458 }
1459
1460 if !content.is_empty() {
1461 let context_message = LanguageModelRequestMessage {
1462 role: Role::User,
1463 content,
1464 cache: false,
1465 };
1466
1467 messages.push(context_message);
1468 }
1469 }
1470
    /// Streams a completion from `model` for `request`, applying each event to
    /// the thread as it arrives (text/thinking chunks, tool uses, usage and
    /// status updates), then handling the terminal stop reason or error when
    /// the stream ends.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Record request/response pairs only when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the telemetry delta below is
            // scoped to this single completion.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool input is recoverable: record it
                                // as a failed tool call and keep streaming.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage events report per-request totals, so
                                // accumulate only the delta since the last one.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant message;
                                // create one lazily if the stream never started one.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are re-emitted
                                // below for incremental UI updates.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so UI work isn't starved by a fast stream.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Walk backwards to the start of the refused
                                    // turn: the most recent user message whose
                                    // predecessor (if any) is an assistant message.
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map well-known errors to dedicated UI treatments;
                            // anything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1851
1852 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1853 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1854 println!("No thread summary model");
1855 return;
1856 };
1857
1858 if !model.provider.is_authenticated(cx) {
1859 return;
1860 }
1861
1862 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1863 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1864 If the conversation is about a specific subject, include it in the title. \
1865 Be descriptive. DO NOT speak in the first person.";
1866
1867 let request = self.to_summarize_request(
1868 &model.model,
1869 CompletionIntent::ThreadSummarization,
1870 added_user_message.into(),
1871 cx,
1872 );
1873
1874 self.summary = ThreadSummary::Generating;
1875
1876 self.pending_summary = cx.spawn(async move |this, cx| {
1877 let result = async {
1878 let mut messages = model.model.stream_completion(request, &cx).await?;
1879
1880 let mut new_summary = String::new();
1881 while let Some(event) = messages.next().await {
1882 let Ok(event) = event else {
1883 continue;
1884 };
1885 let text = match event {
1886 LanguageModelCompletionEvent::Text(text) => text,
1887 LanguageModelCompletionEvent::StatusUpdate(
1888 CompletionRequestStatus::UsageUpdated { amount, limit },
1889 ) => {
1890 this.update(cx, |thread, _cx| {
1891 thread.last_usage = Some(RequestUsage {
1892 limit,
1893 amount: amount as i32,
1894 });
1895 })?;
1896 continue;
1897 }
1898 _ => continue,
1899 };
1900
1901 let mut lines = text.lines();
1902 new_summary.extend(lines.next());
1903
1904 // Stop if the LLM generated multiple lines.
1905 if lines.next().is_some() {
1906 break;
1907 }
1908 }
1909
1910 anyhow::Ok(new_summary)
1911 }
1912 .await;
1913
1914 this.update(cx, |this, cx| {
1915 match result {
1916 Ok(new_summary) => {
1917 if new_summary.is_empty() {
1918 this.summary = ThreadSummary::Error;
1919 } else {
1920 this.summary = ThreadSummary::Ready(new_summary.into());
1921 }
1922 }
1923 Err(err) => {
1924 this.summary = ThreadSummary::Error;
1925 log::error!("Failed to generate thread summary: {}", err);
1926 }
1927 }
1928 cx.emit(ThreadEvent::SummaryGenerated);
1929 })
1930 .log_err()?;
1931
1932 Some(())
1933 });
1934 }
1935
    /// Starts generating a detailed Markdown summary of the thread in the
    /// background, unless one is already generated (or generating) for the
    /// current last message.
    ///
    /// State transitions are published through `detailed_summary_tx`, and the
    /// thread is saved via `thread_store` once the summary completes so it can
    /// be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Mark generation as in-flight before spawning so readers of the
        // channel observe the transition immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Roll the state back so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2030
2031 pub async fn wait_for_detailed_summary_or_text(
2032 this: &Entity<Self>,
2033 cx: &mut AsyncApp,
2034 ) -> Option<SharedString> {
2035 let mut detailed_summary_rx = this
2036 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2037 .ok()?;
2038 loop {
2039 match detailed_summary_rx.recv().await? {
2040 DetailedSummaryState::Generating { .. } => {}
2041 DetailedSummaryState::NotGenerated => {
2042 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2043 }
2044 DetailedSummaryState::Generated { text, .. } => return Some(text),
2045 }
2046 }
2047 }
2048
2049 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2050 self.detailed_summary_rx
2051 .borrow()
2052 .text()
2053 .unwrap_or_else(|| self.text().into())
2054 }
2055
2056 pub fn is_generating_detailed_summary(&self) -> bool {
2057 matches!(
2058 &*self.detailed_summary_rx.borrow(),
2059 DetailedSummaryState::Generating { .. }
2060 )
2061 }
2062
2063 pub fn use_pending_tools(
2064 &mut self,
2065 window: Option<AnyWindowHandle>,
2066 cx: &mut Context<Self>,
2067 model: Arc<dyn LanguageModel>,
2068 ) -> Vec<PendingToolUse> {
2069 self.auto_capture_telemetry(cx);
2070 let request =
2071 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2072 let pending_tool_uses = self
2073 .tool_use
2074 .pending_tool_uses()
2075 .into_iter()
2076 .filter(|tool_use| tool_use.status.is_idle())
2077 .cloned()
2078 .collect::<Vec<_>>();
2079
2080 for tool_use in pending_tool_uses.iter() {
2081 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2082 if tool.needs_confirmation(&tool_use.input, cx)
2083 && !AgentSettings::get_global(cx).always_allow_tool_actions
2084 {
2085 self.tool_use.confirm_tool_use(
2086 tool_use.id.clone(),
2087 tool_use.ui_text.clone(),
2088 tool_use.input.clone(),
2089 request.clone(),
2090 tool,
2091 );
2092 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2093 } else {
2094 self.run_tool(
2095 tool_use.id.clone(),
2096 tool_use.ui_text.clone(),
2097 tool_use.input.clone(),
2098 request.clone(),
2099 tool,
2100 model.clone(),
2101 window,
2102 cx,
2103 );
2104 }
2105 } else {
2106 self.handle_hallucinated_tool_use(
2107 tool_use.id.clone(),
2108 tool_use.name.clone(),
2109 window,
2110 cx,
2111 );
2112 }
2113 }
2114
2115 pending_tool_uses
2116 }
2117
2118 pub fn handle_hallucinated_tool_use(
2119 &mut self,
2120 tool_use_id: LanguageModelToolUseId,
2121 hallucinated_tool_name: Arc<str>,
2122 window: Option<AnyWindowHandle>,
2123 cx: &mut Context<Thread>,
2124 ) {
2125 let available_tools = self.tools.read(cx).enabled_tools(cx);
2126
2127 let tool_list = available_tools
2128 .iter()
2129 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2130 .collect::<Vec<_>>()
2131 .join("\n");
2132
2133 let error_message = format!(
2134 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2135 hallucinated_tool_name, tool_list
2136 );
2137
2138 let pending_tool_use = self.tool_use.insert_tool_output(
2139 tool_use_id.clone(),
2140 hallucinated_tool_name,
2141 Err(anyhow!("Missing tool call: {error_message}")),
2142 self.configured_model.as_ref(),
2143 );
2144
2145 cx.emit(ThreadEvent::MissingToolUse {
2146 tool_use_id: tool_use_id.clone(),
2147 ui_text: error_message.into(),
2148 });
2149
2150 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2151 }
2152
2153 pub fn receive_invalid_tool_json(
2154 &mut self,
2155 tool_use_id: LanguageModelToolUseId,
2156 tool_name: Arc<str>,
2157 invalid_json: Arc<str>,
2158 error: String,
2159 window: Option<AnyWindowHandle>,
2160 cx: &mut Context<Thread>,
2161 ) {
2162 log::error!("The model returned invalid input JSON: {invalid_json}");
2163
2164 let pending_tool_use = self.tool_use.insert_tool_output(
2165 tool_use_id.clone(),
2166 tool_name,
2167 Err(anyhow!("Error parsing input JSON: {error}")),
2168 self.configured_model.as_ref(),
2169 );
2170 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2171 pending_tool_use.ui_text.clone()
2172 } else {
2173 log::error!(
2174 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2175 );
2176 format!("Unknown tool {}", tool_use_id).into()
2177 };
2178
2179 cx.emit(ThreadEvent::InvalidToolInput {
2180 tool_use_id: tool_use_id.clone(),
2181 ui_text,
2182 invalid_input_json: invalid_json,
2183 });
2184
2185 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2186 }
2187
2188 pub fn run_tool(
2189 &mut self,
2190 tool_use_id: LanguageModelToolUseId,
2191 ui_text: impl Into<SharedString>,
2192 input: serde_json::Value,
2193 request: Arc<LanguageModelRequest>,
2194 tool: Arc<dyn Tool>,
2195 model: Arc<dyn LanguageModel>,
2196 window: Option<AnyWindowHandle>,
2197 cx: &mut Context<Thread>,
2198 ) {
2199 let task =
2200 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2201 self.tool_use
2202 .run_pending_tool(tool_use_id, ui_text.into(), task);
2203 }
2204
2205 fn spawn_tool_use(
2206 &mut self,
2207 tool_use_id: LanguageModelToolUseId,
2208 request: Arc<LanguageModelRequest>,
2209 input: serde_json::Value,
2210 tool: Arc<dyn Tool>,
2211 model: Arc<dyn LanguageModel>,
2212 window: Option<AnyWindowHandle>,
2213 cx: &mut Context<Thread>,
2214 ) -> Task<()> {
2215 let tool_name: Arc<str> = tool.name().into();
2216
2217 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2218 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2219 } else {
2220 tool.run(
2221 input,
2222 request,
2223 self.project.clone(),
2224 self.action_log.clone(),
2225 model,
2226 window,
2227 cx,
2228 )
2229 };
2230
2231 // Store the card separately if it exists
2232 if let Some(card) = tool_result.card.clone() {
2233 self.tool_use
2234 .insert_tool_result_card(tool_use_id.clone(), card);
2235 }
2236
2237 cx.spawn({
2238 async move |thread: WeakEntity<Thread>, cx| {
2239 let output = tool_result.output.await;
2240
2241 thread
2242 .update(cx, |thread, cx| {
2243 let pending_tool_use = thread.tool_use.insert_tool_output(
2244 tool_use_id.clone(),
2245 tool_name,
2246 output,
2247 thread.configured_model.as_ref(),
2248 );
2249 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2250 })
2251 .ok();
2252 }
2253 })
2254 }
2255
2256 fn tool_finished(
2257 &mut self,
2258 tool_use_id: LanguageModelToolUseId,
2259 pending_tool_use: Option<PendingToolUse>,
2260 canceled: bool,
2261 window: Option<AnyWindowHandle>,
2262 cx: &mut Context<Self>,
2263 ) {
2264 if self.all_tools_finished() {
2265 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2266 if !canceled {
2267 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2268 }
2269 self.auto_capture_telemetry(cx);
2270 }
2271 }
2272
2273 cx.emit(ThreadEvent::ToolFinished {
2274 tool_use_id,
2275 pending_tool_use,
2276 });
2277 }
2278
2279 /// Cancels the last pending completion, if there are any pending.
2280 ///
2281 /// Returns whether a completion was canceled.
2282 pub fn cancel_last_completion(
2283 &mut self,
2284 window: Option<AnyWindowHandle>,
2285 cx: &mut Context<Self>,
2286 ) -> bool {
2287 let mut canceled = self.pending_completions.pop().is_some();
2288
2289 for pending_tool_use in self.tool_use.cancel_pending() {
2290 canceled = true;
2291 self.tool_finished(
2292 pending_tool_use.id.clone(),
2293 Some(pending_tool_use),
2294 true,
2295 window,
2296 cx,
2297 );
2298 }
2299
2300 if canceled {
2301 cx.emit(ThreadEvent::CompletionCanceled);
2302
2303 // When canceled, we always want to insert the checkpoint.
2304 // (We skip over finalize_pending_checkpoint, because it
2305 // would conclude we didn't have anything to insert here.)
2306 if let Some(checkpoint) = self.pending_checkpoint.take() {
2307 self.insert_checkpoint(checkpoint, cx);
2308 }
2309 } else {
2310 self.finalize_pending_checkpoint(cx);
2311 }
2312
2313 canceled
2314 }
2315
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// [`ThreadEvent::CancelEditing`]; no thread state is mutated here.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2323
    /// Returns the thread-level feedback rating, if one has been recorded
    /// via `report_feedback`.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2327
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2331
2332 pub fn report_message_feedback(
2333 &mut self,
2334 message_id: MessageId,
2335 feedback: ThreadFeedback,
2336 cx: &mut Context<Self>,
2337 ) -> Task<Result<()>> {
2338 if self.message_feedback.get(&message_id) == Some(&feedback) {
2339 return Task::ready(Ok(()));
2340 }
2341
2342 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2343 let serialized_thread = self.serialize(cx);
2344 let thread_id = self.id().clone();
2345 let client = self.project.read(cx).client();
2346
2347 let enabled_tool_names: Vec<String> = self
2348 .tools()
2349 .read(cx)
2350 .enabled_tools(cx)
2351 .iter()
2352 .map(|tool| tool.name())
2353 .collect();
2354
2355 self.message_feedback.insert(message_id, feedback);
2356
2357 cx.notify();
2358
2359 let message_content = self
2360 .message(message_id)
2361 .map(|msg| msg.to_string())
2362 .unwrap_or_default();
2363
2364 cx.background_spawn(async move {
2365 let final_project_snapshot = final_project_snapshot.await;
2366 let serialized_thread = serialized_thread.await?;
2367 let thread_data =
2368 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2369
2370 let rating = match feedback {
2371 ThreadFeedback::Positive => "positive",
2372 ThreadFeedback::Negative => "negative",
2373 };
2374 telemetry::event!(
2375 "Assistant Thread Rated",
2376 rating,
2377 thread_id,
2378 enabled_tool_names,
2379 message_id = message_id.0,
2380 message_content,
2381 thread_data,
2382 final_project_snapshot
2383 );
2384 client.telemetry().flush_events().await;
2385
2386 Ok(())
2387 })
2388 }
2389
2390 pub fn report_feedback(
2391 &mut self,
2392 feedback: ThreadFeedback,
2393 cx: &mut Context<Self>,
2394 ) -> Task<Result<()>> {
2395 let last_assistant_message_id = self
2396 .messages
2397 .iter()
2398 .rev()
2399 .find(|msg| msg.role == Role::Assistant)
2400 .map(|msg| msg.id);
2401
2402 if let Some(message_id) = last_assistant_message_id {
2403 self.report_message_feedback(message_id, feedback, cx)
2404 } else {
2405 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2406 let serialized_thread = self.serialize(cx);
2407 let thread_id = self.id().clone();
2408 let client = self.project.read(cx).client();
2409 self.feedback = Some(feedback);
2410 cx.notify();
2411
2412 cx.background_spawn(async move {
2413 let final_project_snapshot = final_project_snapshot.await;
2414 let serialized_thread = serialized_thread.await?;
2415 let thread_data = serde_json::to_value(serialized_thread)
2416 .unwrap_or_else(|_| serde_json::Value::Null);
2417
2418 let rating = match feedback {
2419 ThreadFeedback::Positive => "positive",
2420 ThreadFeedback::Negative => "negative",
2421 };
2422 telemetry::event!(
2423 "Assistant Thread Rated",
2424 rating,
2425 thread_id,
2426 thread_data,
2427 final_project_snapshot
2428 );
2429 client.telemetry().flush_events().await;
2430
2431 Ok(())
2432 })
2433 }
2434 }
2435
2436 /// Create a snapshot of the current project state including git information and unsaved buffers.
2437 fn project_snapshot(
2438 project: Entity<Project>,
2439 cx: &mut Context<Self>,
2440 ) -> Task<Arc<ProjectSnapshot>> {
2441 let git_store = project.read(cx).git_store().clone();
2442 let worktree_snapshots: Vec<_> = project
2443 .read(cx)
2444 .visible_worktrees(cx)
2445 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2446 .collect();
2447
2448 cx.spawn(async move |_, cx| {
2449 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2450
2451 let mut unsaved_buffers = Vec::new();
2452 cx.update(|app_cx| {
2453 let buffer_store = project.read(app_cx).buffer_store();
2454 for buffer_handle in buffer_store.read(app_cx).buffers() {
2455 let buffer = buffer_handle.read(app_cx);
2456 if buffer.is_dirty() {
2457 if let Some(file) = buffer.file() {
2458 let path = file.path().to_string_lossy().to_string();
2459 unsaved_buffers.push(path);
2460 }
2461 }
2462 }
2463 })
2464 .ok();
2465
2466 Arc::new(ProjectSnapshot {
2467 worktree_snapshots,
2468 unsaved_buffer_paths: unsaved_buffers,
2469 timestamp: Utc::now(),
2470 })
2471 })
2472 }
2473
2474 fn worktree_snapshot(
2475 worktree: Entity<project::Worktree>,
2476 git_store: Entity<GitStore>,
2477 cx: &App,
2478 ) -> Task<WorktreeSnapshot> {
2479 cx.spawn(async move |cx| {
2480 // Get worktree path and snapshot
2481 let worktree_info = cx.update(|app_cx| {
2482 let worktree = worktree.read(app_cx);
2483 let path = worktree.abs_path().to_string_lossy().to_string();
2484 let snapshot = worktree.snapshot();
2485 (path, snapshot)
2486 });
2487
2488 let Ok((worktree_path, _snapshot)) = worktree_info else {
2489 return WorktreeSnapshot {
2490 worktree_path: String::new(),
2491 git_state: None,
2492 };
2493 };
2494
2495 let git_state = git_store
2496 .update(cx, |git_store, cx| {
2497 git_store
2498 .repositories()
2499 .values()
2500 .find(|repo| {
2501 repo.read(cx)
2502 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2503 .is_some()
2504 })
2505 .cloned()
2506 })
2507 .ok()
2508 .flatten()
2509 .map(|repo| {
2510 repo.update(cx, |repo, _| {
2511 let current_branch =
2512 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2513 repo.send_job(None, |state, _| async move {
2514 let RepositoryState::Local { backend, .. } = state else {
2515 return GitState {
2516 remote_url: None,
2517 head_sha: None,
2518 current_branch,
2519 diff: None,
2520 };
2521 };
2522
2523 let remote_url = backend.remote_url("origin");
2524 let head_sha = backend.head_sha().await;
2525 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2526
2527 GitState {
2528 remote_url,
2529 head_sha,
2530 current_branch,
2531 diff,
2532 }
2533 })
2534 })
2535 });
2536
2537 let git_state = match git_state {
2538 Some(git_state) => match git_state.ok() {
2539 Some(git_state) => git_state.await.ok(),
2540 None => None,
2541 },
2542 None => None,
2543 };
2544
2545 WorktreeSnapshot {
2546 worktree_path,
2547 git_state,
2548 }
2549 })
2550 }
2551
2552 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2553 let mut markdown = Vec::new();
2554
2555 let summary = self.summary().or_default();
2556 writeln!(markdown, "# {summary}\n")?;
2557
2558 for message in self.messages() {
2559 writeln!(
2560 markdown,
2561 "## {role}\n",
2562 role = match message.role {
2563 Role::User => "User",
2564 Role::Assistant => "Agent",
2565 Role::System => "System",
2566 }
2567 )?;
2568
2569 if !message.loaded_context.text.is_empty() {
2570 writeln!(markdown, "{}", message.loaded_context.text)?;
2571 }
2572
2573 if !message.loaded_context.images.is_empty() {
2574 writeln!(
2575 markdown,
2576 "\n{} images attached as context.\n",
2577 message.loaded_context.images.len()
2578 )?;
2579 }
2580
2581 for segment in &message.segments {
2582 match segment {
2583 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2584 MessageSegment::Thinking { text, .. } => {
2585 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2586 }
2587 MessageSegment::RedactedThinking(_) => {}
2588 }
2589 }
2590
2591 for tool_use in self.tool_uses_for_message(message.id, cx) {
2592 writeln!(
2593 markdown,
2594 "**Use Tool: {} ({})**",
2595 tool_use.name, tool_use.id
2596 )?;
2597 writeln!(markdown, "```json")?;
2598 writeln!(
2599 markdown,
2600 "{}",
2601 serde_json::to_string_pretty(&tool_use.input)?
2602 )?;
2603 writeln!(markdown, "```")?;
2604 }
2605
2606 for tool_result in self.tool_results_for_message(message.id) {
2607 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2608 if tool_result.is_error {
2609 write!(markdown, " (Error)")?;
2610 }
2611
2612 writeln!(markdown, "**\n")?;
2613 match &tool_result.content {
2614 LanguageModelToolResultContent::Text(text)
2615 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2616 text,
2617 ..
2618 }) => {
2619 writeln!(markdown, "{text}")?;
2620 }
2621 LanguageModelToolResultContent::Image(image) => {
2622 writeln!(markdown, "", image.source)?;
2623 }
2624 }
2625
2626 if let Some(output) = tool_result.output.as_ref() {
2627 writeln!(
2628 markdown,
2629 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2630 serde_json::to_string_pretty(output)?
2631 )?;
2632 }
2633 }
2634 }
2635
2636 Ok(String::from_utf8_lossy(&markdown).to_string())
2637 }
2638
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted by the user), delegating to the [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2649
    /// Marks every tracked agent edit as kept, delegating to the [`ActionLog`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2654
    /// Reverts the agent's edits within the given ranges of `buffer`,
    /// delegating to the [`ActionLog`]. Returns the log's async result.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2665
    /// The [`ActionLog`] tracking this thread's edits to project buffers.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2669
    /// The [`Project`] this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2673
2674 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2675 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2676 return;
2677 }
2678
2679 let now = Instant::now();
2680 if let Some(last) = self.last_auto_capture_at {
2681 if now.duration_since(last).as_secs() < 10 {
2682 return;
2683 }
2684 }
2685
2686 self.last_auto_capture_at = Some(now);
2687
2688 let thread_id = self.id().clone();
2689 let github_login = self
2690 .project
2691 .read(cx)
2692 .user_store()
2693 .read(cx)
2694 .current_user()
2695 .map(|user| user.github_login.clone());
2696 let client = self.project.read(cx).client();
2697 let serialize_task = self.serialize(cx);
2698
2699 cx.background_executor()
2700 .spawn(async move {
2701 if let Ok(serialized_thread) = serialize_task.await {
2702 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2703 telemetry::event!(
2704 "Agent Thread Auto-Captured",
2705 thread_id = thread_id.to_string(),
2706 thread_data = thread_data,
2707 auto_capture_reason = "tracked_user",
2708 github_login = github_login
2709 );
2710
2711 client.telemetry().flush_events().await;
2712 }
2713 }
2714 })
2715 .detach();
2716 }
2717
    /// Total token usage accumulated across every request in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2721
2722 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2723 let Some(model) = self.configured_model.as_ref() else {
2724 return TotalTokenUsage::default();
2725 };
2726
2727 let max = model.model.max_token_count();
2728
2729 let index = self
2730 .messages
2731 .iter()
2732 .position(|msg| msg.id == message_id)
2733 .unwrap_or(0);
2734
2735 if index == 0 {
2736 return TotalTokenUsage { total: 0, max };
2737 }
2738
2739 let token_usage = &self
2740 .request_token_usage
2741 .get(index - 1)
2742 .cloned()
2743 .unwrap_or_default();
2744
2745 TotalTokenUsage {
2746 total: token_usage.total_tokens() as usize,
2747 max,
2748 }
2749 }
2750
2751 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2752 let model = self.configured_model.as_ref()?;
2753
2754 let max = model.model.max_token_count();
2755
2756 if let Some(exceeded_error) = &self.exceeded_window_error {
2757 if model.model.id() == exceeded_error.model_id {
2758 return Some(TotalTokenUsage {
2759 total: exceeded_error.token_count,
2760 max,
2761 });
2762 }
2763 }
2764
2765 let total = self
2766 .token_usage_at_last_message()
2767 .unwrap_or_default()
2768 .total_tokens() as usize;
2769
2770 Some(TotalTokenUsage { total, max })
2771 }
2772
2773 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2774 self.request_token_usage
2775 .get(self.messages.len().saturating_sub(1))
2776 .or_else(|| self.request_token_usage.last())
2777 .cloned()
2778 }
2779
    /// Resizes the per-message usage history to match the message count and
    /// overwrites the final entry with `token_usage`.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad (or truncate) the history with the most recent known usage so
        // its indices stay aligned with `self.messages`.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        // No-op when there are no messages (history resized to empty).
        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2789
2790 pub fn deny_tool_use(
2791 &mut self,
2792 tool_use_id: LanguageModelToolUseId,
2793 tool_name: Arc<str>,
2794 window: Option<AnyWindowHandle>,
2795 cx: &mut Context<Self>,
2796 ) {
2797 let err = Err(anyhow::anyhow!(
2798 "Permission to run tool action denied by user"
2799 ));
2800
2801 self.tool_use.insert_tool_output(
2802 tool_use_id.clone(),
2803 tool_name,
2804 err,
2805 self.configured_model.as_ref(),
2806 );
2807 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2808 }
2809}
2810
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before further completions.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has hit its model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A general-purpose error carrying a display header and message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2823
/// Events emitted by a [`Thread`] as completions stream, tools run, and
/// messages change; consumed by UI layers via [`EventEmitter`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    /// A completion streamed new data.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streaming assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streaming "thinking" content was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use finished without a matching pending tool use
    /// (presumably an unknown/unavailable tool — confirm at emit sites).
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// A tool use finished but its input was not valid JSON; carries the
    /// raw invalid payload.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content changed.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary text changed.
    SummaryChanged,
    /// The model requested tool calls that are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool call completed (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The set of checkpoints changed (e.g. one was inserted on cancel).
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Listeners (like ActiveThread) should cancel in-progress edits.
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
}
2866
// Lets `Thread` broadcast `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2868
/// Bookkeeping for an in-flight completion request.
struct PendingCompletion {
    // Identifier for this completion request.
    id: usize,
    queue_state: QueueState,
    // Keeps the streaming task alive; presumably dropping it cancels the
    // request — TODO confirm against gpui `Task` semantics.
    _task: Task<()>,
}
2874
2875#[cfg(test)]
2876mod tests {
2877 use super::*;
2878 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2879 use agent_settings::{AgentSettings, LanguageModelParameters};
2880 use assistant_tool::ToolRegistry;
2881 use editor::EditorSettings;
2882 use gpui::TestAppContext;
2883 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2884 use project::{FakeFs, Project};
2885 use prompt_store::PromptBuilder;
2886 use serde_json::json;
2887 use settings::{Settings, SettingsStore};
2888 use std::sync::Arc;
2889 use theme::ThemeSettings;
2890 use util::path;
2891 use workspace::Workspace;
2892
    /// Verifies that a file attached as context is embedded in the user
    /// message's loaded context and rendered into the completion request
    /// ahead of the user's text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2969
    /// Verifies that each user message only carries context that hasn't been
    /// attached to an earlier message, and that `new_context_for_thread`
    /// respects an "exclude messages after" cutoff and context removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts from message 3 onward count as
        // "new" again, alongside the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3125
3126 #[gpui::test]
3127 async fn test_message_without_files(cx: &mut TestAppContext) {
3128 init_test_settings(cx);
3129
3130 let project = create_test_project(
3131 cx,
3132 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3133 )
3134 .await;
3135
3136 let (_, _thread_store, thread, _context_store, model) =
3137 setup_test_environment(cx, project.clone()).await;
3138
3139 // Insert user message without any context (empty context vector)
3140 let message_id = thread.update(cx, |thread, cx| {
3141 thread.insert_user_message(
3142 "What is the best way to learn Rust?",
3143 ContextLoadResult::default(),
3144 None,
3145 Vec::new(),
3146 cx,
3147 )
3148 });
3149
3150 // Check content and context in message object
3151 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3152
3153 // Context should be empty when no files are included
3154 assert_eq!(message.role, Role::User);
3155 assert_eq!(message.segments.len(), 1);
3156 assert_eq!(
3157 message.segments[0],
3158 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3159 );
3160 assert_eq!(message.loaded_context.text, "");
3161
3162 // Check message in request
3163 let request = thread.update(cx, |thread, cx| {
3164 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3165 });
3166
3167 assert_eq!(request.messages.len(), 2);
3168 assert_eq!(
3169 request.messages[1].string_contents(),
3170 "What is the best way to learn Rust?"
3171 );
3172
3173 // Add second message, also without context
3174 let message2_id = thread.update(cx, |thread, cx| {
3175 thread.insert_user_message(
3176 "Are there any good books?",
3177 ContextLoadResult::default(),
3178 None,
3179 Vec::new(),
3180 cx,
3181 )
3182 });
3183
3184 let message2 =
3185 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3186 assert_eq!(message2.loaded_context.text, "");
3187
3188 // Check that both messages appear in the request
3189 let request = thread.update(cx, |thread, cx| {
3190 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3191 });
3192
3193 assert_eq!(request.messages.len(), 3);
3194 assert_eq!(
3195 request.messages[1].string_contents(),
3196 "What is the best way to learn Rust?"
3197 );
3198 assert_eq!(
3199 request.messages[2].string_contents(),
3200 "Are there any good books?"
3201 );
3202 }
3203
    /// Verifies that editing a buffer after it was attached as context causes
    /// the next completion request to end with a "files changed since last
    /// read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3291
3292 #[gpui::test]
3293 async fn test_temperature_setting(cx: &mut TestAppContext) {
3294 init_test_settings(cx);
3295
3296 let project = create_test_project(
3297 cx,
3298 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3299 )
3300 .await;
3301
3302 let (_workspace, _thread_store, thread, _context_store, model) =
3303 setup_test_environment(cx, project.clone()).await;
3304
3305 // Both model and provider
3306 cx.update(|cx| {
3307 AgentSettings::override_global(
3308 AgentSettings {
3309 model_parameters: vec![LanguageModelParameters {
3310 provider: Some(model.provider_id().0.to_string().into()),
3311 model: Some(model.id().0.clone()),
3312 temperature: Some(0.66),
3313 }],
3314 ..AgentSettings::get_global(cx).clone()
3315 },
3316 cx,
3317 );
3318 });
3319
3320 let request = thread.update(cx, |thread, cx| {
3321 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3322 });
3323 assert_eq!(request.temperature, Some(0.66));
3324
3325 // Only model
3326 cx.update(|cx| {
3327 AgentSettings::override_global(
3328 AgentSettings {
3329 model_parameters: vec![LanguageModelParameters {
3330 provider: None,
3331 model: Some(model.id().0.clone()),
3332 temperature: Some(0.66),
3333 }],
3334 ..AgentSettings::get_global(cx).clone()
3335 },
3336 cx,
3337 );
3338 });
3339
3340 let request = thread.update(cx, |thread, cx| {
3341 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3342 });
3343 assert_eq!(request.temperature, Some(0.66));
3344
3345 // Only provider
3346 cx.update(|cx| {
3347 AgentSettings::override_global(
3348 AgentSettings {
3349 model_parameters: vec![LanguageModelParameters {
3350 provider: Some(model.provider_id().0.to_string().into()),
3351 model: None,
3352 temperature: Some(0.66),
3353 }],
3354 ..AgentSettings::get_global(cx).clone()
3355 },
3356 cx,
3357 );
3358 });
3359
3360 let request = thread.update(cx, |thread, cx| {
3361 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3362 });
3363 assert_eq!(request.temperature, Some(0.66));
3364
3365 // Same model name, different provider
3366 cx.update(|cx| {
3367 AgentSettings::override_global(
3368 AgentSettings {
3369 model_parameters: vec![LanguageModelParameters {
3370 provider: Some("anthropic".into()),
3371 model: Some(model.id().0.clone()),
3372 temperature: Some(0.66),
3373 }],
3374 ..AgentSettings::get_global(cx).clone()
3375 },
3376 cx,
3377 );
3378 });
3379
3380 let request = thread.update(cx, |thread, cx| {
3381 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3382 });
3383 assert_eq!(request.temperature, None);
3384 }
3385
3386 #[gpui::test]
3387 async fn test_thread_summary(cx: &mut TestAppContext) {
3388 init_test_settings(cx);
3389
3390 let project = create_test_project(cx, json!({})).await;
3391
3392 let (_, _thread_store, thread, _context_store, model) =
3393 setup_test_environment(cx, project.clone()).await;
3394
3395 // Initial state should be pending
3396 thread.read_with(cx, |thread, _| {
3397 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3398 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3399 });
3400
3401 // Manually setting the summary should not be allowed in this state
3402 thread.update(cx, |thread, cx| {
3403 thread.set_summary("This should not work", cx);
3404 });
3405
3406 thread.read_with(cx, |thread, _| {
3407 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3408 });
3409
3410 // Send a message
3411 thread.update(cx, |thread, cx| {
3412 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3413 thread.send_to_model(
3414 model.clone(),
3415 CompletionIntent::ThreadSummarization,
3416 None,
3417 cx,
3418 );
3419 });
3420
3421 let fake_model = model.as_fake();
3422 simulate_successful_response(&fake_model, cx);
3423
3424 // Should start generating summary when there are >= 2 messages
3425 thread.read_with(cx, |thread, _| {
3426 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3427 });
3428
3429 // Should not be able to set the summary while generating
3430 thread.update(cx, |thread, cx| {
3431 thread.set_summary("This should not work either", cx);
3432 });
3433
3434 thread.read_with(cx, |thread, _| {
3435 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3436 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3437 });
3438
3439 cx.run_until_parked();
3440 fake_model.stream_last_completion_response("Brief");
3441 fake_model.stream_last_completion_response(" Introduction");
3442 fake_model.end_last_completion_stream();
3443 cx.run_until_parked();
3444
3445 // Summary should be set
3446 thread.read_with(cx, |thread, _| {
3447 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3448 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3449 });
3450
3451 // Now we should be able to set a summary
3452 thread.update(cx, |thread, cx| {
3453 thread.set_summary("Brief Intro", cx);
3454 });
3455
3456 thread.read_with(cx, |thread, _| {
3457 assert_eq!(thread.summary().or_default(), "Brief Intro");
3458 });
3459
3460 // Test setting an empty summary (should default to DEFAULT)
3461 thread.update(cx, |thread, cx| {
3462 thread.set_summary("", cx);
3463 });
3464
3465 thread.read_with(cx, |thread, _| {
3466 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3467 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3468 });
3469 }
3470
3471 #[gpui::test]
3472 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3473 init_test_settings(cx);
3474
3475 let project = create_test_project(cx, json!({})).await;
3476
3477 let (_, _thread_store, thread, _context_store, model) =
3478 setup_test_environment(cx, project.clone()).await;
3479
3480 test_summarize_error(&model, &thread, cx);
3481
3482 // Now we should be able to set a summary
3483 thread.update(cx, |thread, cx| {
3484 thread.set_summary("Brief Intro", cx);
3485 });
3486
3487 thread.read_with(cx, |thread, _| {
3488 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3489 assert_eq!(thread.summary().or_default(), "Brief Intro");
3490 });
3491 }
3492
3493 #[gpui::test]
3494 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3495 init_test_settings(cx);
3496
3497 let project = create_test_project(cx, json!({})).await;
3498
3499 let (_, _thread_store, thread, _context_store, model) =
3500 setup_test_environment(cx, project.clone()).await;
3501
3502 test_summarize_error(&model, &thread, cx);
3503
3504 // Sending another message should not trigger another summarize request
3505 thread.update(cx, |thread, cx| {
3506 thread.insert_user_message(
3507 "How are you?",
3508 ContextLoadResult::default(),
3509 None,
3510 vec![],
3511 cx,
3512 );
3513 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3514 });
3515
3516 let fake_model = model.as_fake();
3517 simulate_successful_response(&fake_model, cx);
3518
3519 thread.read_with(cx, |thread, _| {
3520 // State is still Error, not Generating
3521 assert!(matches!(thread.summary(), ThreadSummary::Error));
3522 });
3523
3524 // But the summarize request can be invoked manually
3525 thread.update(cx, |thread, cx| {
3526 thread.summarize(cx);
3527 });
3528
3529 thread.read_with(cx, |thread, _| {
3530 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3531 });
3532
3533 cx.run_until_parked();
3534 fake_model.stream_last_completion_response("A successful summary");
3535 fake_model.end_last_completion_stream();
3536 cx.run_until_parked();
3537
3538 thread.read_with(cx, |thread, _| {
3539 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3540 assert_eq!(thread.summary().or_default(), "A successful summary");
3541 });
3542 }
3543
3544 fn test_summarize_error(
3545 model: &Arc<dyn LanguageModel>,
3546 thread: &Entity<Thread>,
3547 cx: &mut TestAppContext,
3548 ) {
3549 thread.update(cx, |thread, cx| {
3550 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3551 thread.send_to_model(
3552 model.clone(),
3553 CompletionIntent::ThreadSummarization,
3554 None,
3555 cx,
3556 );
3557 });
3558
3559 let fake_model = model.as_fake();
3560 simulate_successful_response(&fake_model, cx);
3561
3562 thread.read_with(cx, |thread, _| {
3563 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3564 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3565 });
3566
3567 // Simulate summary request ending
3568 cx.run_until_parked();
3569 fake_model.end_last_completion_stream();
3570 cx.run_until_parked();
3571
3572 // State is set to Error and default message
3573 thread.read_with(cx, |thread, _| {
3574 assert!(matches!(thread.summary(), ThreadSummary::Error));
3575 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3576 });
3577 }
3578
    // Drives the fake model through one successful assistant turn: waits for
    // the pending completion request to be issued, streams a canned response,
    // and closes the stream so the thread observes the turn as complete.
    // The surrounding `run_until_parked` calls let the async plumbing settle
    // before and after the stream is fed.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3585
    // Installs the global state every test in this module depends on: a test
    // `SettingsStore` plus the settings/registrations for languages, projects,
    // the agent, prompts, thread persistence, workspaces, language models,
    // themes, editors, and the global tool registry.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be in place before any of the
            // `init`/`register` calls below read or register settings.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3602
3603 // Helper to create a test project with test files
3604 async fn create_test_project(
3605 cx: &mut TestAppContext,
3606 files: serde_json::Value,
3607 ) -> Entity<Project> {
3608 let fs = FakeFs::new(cx.executor());
3609 fs.insert_tree(path!("/test"), files).await;
3610 Project::test(fs, [path!("/test").as_ref()], cx).await
3611 }
3612
3613 async fn setup_test_environment(
3614 cx: &mut TestAppContext,
3615 project: Entity<Project>,
3616 ) -> (
3617 Entity<Workspace>,
3618 Entity<ThreadStore>,
3619 Entity<Thread>,
3620 Entity<ContextStore>,
3621 Arc<dyn LanguageModel>,
3622 ) {
3623 let (workspace, cx) =
3624 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3625
3626 let thread_store = cx
3627 .update(|_, cx| {
3628 ThreadStore::load(
3629 project.clone(),
3630 cx.new(|_| ToolWorkingSet::default()),
3631 None,
3632 Arc::new(PromptBuilder::new(None).unwrap()),
3633 cx,
3634 )
3635 })
3636 .await
3637 .unwrap();
3638
3639 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3640 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3641
3642 let provider = Arc::new(FakeLanguageModelProvider);
3643 let model = provider.test_model();
3644 let model: Arc<dyn LanguageModel> = Arc::new(model);
3645
3646 cx.update(|_, cx| {
3647 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3648 registry.set_default_model(
3649 Some(ConfiguredModel {
3650 provider: provider.clone(),
3651 model: model.clone(),
3652 }),
3653 cx,
3654 );
3655 registry.set_thread_summary_model(
3656 Some(ConfiguredModel {
3657 provider,
3658 model: model.clone(),
3659 }),
3660 cx,
3661 );
3662 })
3663 });
3664
3665 (workspace, thread_store, thread, context_store, model)
3666 }
3667
3668 async fn add_file_to_context(
3669 project: &Entity<Project>,
3670 context_store: &Entity<ContextStore>,
3671 path: &str,
3672 cx: &mut TestAppContext,
3673 ) -> Result<Entity<language::Buffer>> {
3674 let buffer_path = project
3675 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3676 .unwrap();
3677
3678 let buffer = project
3679 .update(cx, |project, cx| {
3680 project.open_buffer(buffer_path.clone(), cx)
3681 })
3682 .await
3683 .unwrap();
3684
3685 context_store.update(cx, |context_store, cx| {
3686 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3687 });
3688
3689 Ok(buffer)
3690 }
3691}