1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// A stable, globally unique identifier for a [`Thread`].
///
/// Backed by an `Arc<str>`, so cloning is a cheap refcount bump.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a new thread ID from a freshly generated v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new prompt ID from a freshly generated v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Identifier of a message within a single thread; assigned monotonically.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one
    /// (post-increment semantics, via [`util::post_inc`]).
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text covered by the crease.
    pub range: Range<usize>,
    /// Icon/label used to render the crease placeholder.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered segments (text / thinking / redacted thinking) making up the body.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached alongside this message.
    pub loaded_context: LoadedContext,
    /// Creases used to restore context mentions when re-editing the message.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages (e.g. the auto-inserted "continue" prompt) are kept in
    /// the transcript but excluded from turn-end detection (see `is_turn_end`).
    pub is_hidden: bool,
}
120
121impl Message {
122 /// Returns whether the message contains any meaningful text that should be displayed
123 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
124 pub fn should_display_content(&self) -> bool {
125 self.segments.iter().all(|segment| segment.should_display())
126 }
127
128 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
129 if let Some(MessageSegment::Thinking {
130 text: segment,
131 signature: current_signature,
132 }) = self.segments.last_mut()
133 {
134 if let Some(signature) = signature {
135 *current_signature = Some(signature);
136 }
137 segment.push_str(text);
138 } else {
139 self.segments.push(MessageSegment::Thinking {
140 text: text.to_string(),
141 signature,
142 });
143 }
144 }
145
146 pub fn push_text(&mut self, text: &str) {
147 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Text(text.to_string()));
151 }
152 }
153
154 pub fn to_string(&self) -> String {
155 let mut result = String::new();
156
157 if !self.loaded_context.text.is_empty() {
158 result.push_str(&self.loaded_context.text);
159 }
160
161 for segment in &self.segments {
162 match segment {
163 MessageSegment::Text(text) => result.push_str(text),
164 MessageSegment::Thinking { text, .. } => {
165 result.push_str("<think>\n");
166 result.push_str(text);
167 result.push_str("\n</think>");
168 }
169 MessageSegment::RedactedThinking(_) => {}
170 }
171 }
172
173 result
174 }
175}
176
/// One contiguous piece of a message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text authored by the user or produced by the model.
    Text(String),
    /// A "thinking" block streamed by reasoning models, with an optional
    /// provider-issued signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-encrypted thinking content; never rendered.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has content worth rendering.
    ///
    /// Fix: this previously returned `text.is_empty()` — the exact inverse of
    /// its name and of the expectation in [`Message::should_display_content`] —
    /// so empty segments were reported as displayable while non-empty ones
    /// were hidden.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is intentionally opaque.
            Self::RedactedThinking(_) => false,
        }
    }
}
196
/// Snapshot of the project's state at a point in time (used for telemetry
/// and thread persistence).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Captured git repository metadata for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, when one was captured.
    pub diff: Option<String>,
}
217
/// Associates a git checkpoint with the message that triggered it, so the
/// project can be restored to its state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
223
/// User-submitted thumbs-up/down feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
229
230pub enum LastRestoreCheckpoint {
231 Pending {
232 message_id: MessageId,
233 },
234 Error {
235 message_id: MessageId,
236 error: String,
237 },
238}
239
240impl LastRestoreCheckpoint {
241 pub fn message_id(&self) -> MessageId {
242 match self {
243 LastRestoreCheckpoint::Pending { message_id } => *message_id,
244 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
245 }
246 }
247}
248
/// Progress of the long-form ("detailed") thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been requested yet.
    #[default]
    NotGenerated,
    /// Generation is running; `message_id` is the message it covers up to.
    Generating {
        message_id: MessageId,
    },
    /// Generation finished with `text`, covering up to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
261
262impl DetailedSummaryState {
263 fn text(&self) -> Option<SharedString> {
264 if let Self::Generated { text, .. } = self {
265 Some(text.clone())
266 } else {
267 None
268 }
269 }
270}
271
/// Aggregate token consumption for a thread, relative to the model's
/// context-window size.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size of the selected model; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD environment variable.
        //
        // Fix: previously a present-but-unparsable value caused a panic via
        // `.parse().unwrap()`; now any missing or malformed value falls back
        // to the default threshold.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the thread is to exhausting the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
316
/// Progress of a pending completion before/while the provider services it.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is queued at the given position.
    Queued { position: usize },
    /// The provider has started producing the completion.
    Started,
}
323
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Ordered by ascending `MessageId`; `Self::message` relies on this for
    // binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured when the latest user message was inserted;
    // finalized or discarded once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage, parallel to completions issued.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recently streamed chunk; drives staleness checks.
    last_received_chunk_at: Option<Instant>,
    // Test/debug hook invoked with each request and its collected events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
364
/// Lifecycle of a thread's auto-generated summary (its display title).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// Generation has not started yet.
    Pending,
    /// Generation is in progress.
    Generating,
    /// Generation completed with this text.
    Ready(SharedString),
    /// Generation failed.
    Error,
}
372
373impl ThreadSummary {
374 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
375
376 pub fn or_default(&self) -> SharedString {
377 self.unwrap_or(Self::DEFAULT)
378 }
379
380 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
381 self.ready().unwrap_or_else(|| message.into())
382 }
383
384 pub fn ready(&self) -> Option<SharedString> {
385 match self {
386 ThreadSummary::Ready(summary) => Some(summary.clone()),
387 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
388 }
389 }
390}
391
/// Records that a request overflowed the model's context window, so the UI
/// can surface it until the situation is resolved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
399
400impl Thread {
    /// Creates a fresh, empty thread for `project`.
    ///
    /// The configured model starts as the registry's global default, and an
    /// initial project snapshot is captured asynchronously.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot asynchronously and share the future so
                // multiple consumers can await it.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
454
    /// Reconstructs a thread from its serialized form.
    ///
    /// The saved model selection is resolved against the registry (falling
    /// back to the default model), and message creases are restored without
    /// their live context handles (`context: None`).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest saved message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model saved with the thread; fall back to the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; live
                    // context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
575
    /// Installs a hook that observes every completion request and the events
    /// it produced (used by tests/eval tooling).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
583
    /// This thread's stable identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
587
    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
591
    /// Timestamp of the last modification to this thread.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
595
    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
599
    /// Starts a new prompt ID; called when the user submits a new message.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
603
    /// The shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
607
    /// Returns the configured model, lazily falling back to the registry's
    /// default model the first time none is set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }
614
    /// The currently configured model, if any (no fallback).
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
618
    /// Replaces the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
623
    /// Current state of the thread's auto-generated summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
627
628 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
629 let current_summary = match &self.summary {
630 ThreadSummary::Pending | ThreadSummary::Generating => return,
631 ThreadSummary::Ready(summary) => summary,
632 ThreadSummary::Error => &ThreadSummary::DEFAULT,
633 };
634
635 let mut new_summary = new_summary.into();
636
637 if new_summary.is_empty() {
638 new_summary = ThreadSummary::DEFAULT;
639 }
640
641 if current_summary != &new_summary {
642 self.summary = ThreadSummary::Ready(new_summary);
643 cx.emit(ThreadEvent::SummaryChanged);
644 }
645 }
646
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
650
    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
654
655 pub fn message(&self, id: MessageId) -> Option<&Message> {
656 let index = self
657 .messages
658 .binary_search_by(|message| message.id.cmp(&id))
659 .ok()?;
660
661 self.messages.get(index)
662 }
663
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
667
    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
671
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a streamed chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): this returns `Some` whenever a chunk has ever been
        // received; the doc's `is_generating()` guarantee relies on
        // `last_received_chunk_at` being reset elsewhere — confirm.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
680
    /// Records that a streamed chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
684
    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
690
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
694
695 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
696 self.tool_use
697 .pending_tool_uses()
698 .into_iter()
699 .find(|tool_use| &tool_use.id == id)
700 }
701
    /// Pending tool uses that are waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
708
    /// Whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
712
    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
716
    /// Restores the project to `checkpoint` and, on success, truncates the
    /// thread back to the checkpoint's message.
    ///
    /// `last_restore_checkpoint` is set to `Pending` immediately and to
    /// `Error`/cleared when the git restore finishes; observers are notified
    /// at both transitions.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Successful restore: drop the restored-past messages too.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
752
753 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
754 let pending_checkpoint = if self.is_generating() {
755 return;
756 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
757 checkpoint
758 } else {
759 return;
760 };
761
762 self.finalize_checkpoint(pending_checkpoint, cx);
763 }
764
    /// Compares the pending checkpoint against the repository's current state
    /// and keeps it only when the project actually changed.
    ///
    /// If comparing fails, the checkpoint is kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed".
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
799
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
806
    /// State of the most recent checkpoint restore, if one is pending/errored.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
810
811 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
812 let Some(message_ix) = self
813 .messages
814 .iter()
815 .rposition(|message| message.id == message_id)
816 else {
817 return;
818 };
819 for deleted_message in self.messages.drain(message_ix..) {
820 self.checkpoints_by_message.remove(&deleted_message.id);
821 }
822 cx.notify();
823 }
824
    /// Iterates over the context items attached to the given message
    /// (empty iterator when the message doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
832
833 pub fn is_turn_end(&self, ix: usize) -> bool {
834 if self.messages.is_empty() {
835 return false;
836 }
837
838 if !self.is_generating() && ix == self.messages.len() - 1 {
839 return true;
840 }
841
842 let Some(message) = self.messages.get(ix) else {
843 return false;
844 };
845
846 if message.role != Role::Assistant {
847 return false;
848 }
849
850 self.messages
851 .get(ix + 1)
852 .and_then(|message| {
853 self.message(message.id)
854 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
855 })
856 .unwrap_or(false)
857 }
858
    /// Usage reported by the provider for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }
862
    /// Whether the provider signalled that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
866
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
876
    /// Tool uses issued by the assistant message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
880
    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
887
    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
891
892 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
893 match &self.tool_use.tool_result(id)?.content {
894 LanguageModelToolResultContent::Text(text) => Some(text),
895 LanguageModelToolResultContent::Image(_) => {
896 // TODO: We should display image
897 None
898 }
899 }
900 }
901
    /// The UI card rendered for a tool result, if one was registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
905
906 /// Return tools that are both enabled and supported by the model
907 pub fn available_tools(
908 &self,
909 cx: &App,
910 model: Arc<dyn LanguageModel>,
911 ) -> Vec<LanguageModelRequestTool> {
912 if model.supports_tools() {
913 self.tools()
914 .read(cx)
915 .enabled_tools(cx)
916 .into_iter()
917 .filter_map(|tool| {
918 // Skip tools that cannot be supported
919 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
920 Some(LanguageModelRequestTool {
921 name: tool.name(),
922 description: tool.description(),
923 input_schema,
924 })
925 })
926 .collect()
927 } else {
928 Vec::default()
929 }
930 }
931
    /// Appends a visible user message, marking any referenced buffers as read
    /// in the action log and remembering `git_checkpoint` as the pending
    /// checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // The checkpoint is finalized (or discarded) when generation ends.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
968
    /// Appends a hidden "Continue where you left off" user message, used to
    /// resume generation without user-authored input. Clears any pending
    /// checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
982
    /// Appends a visible assistant message with the given segments and no
    /// attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
997
    /// Appends a message with the next sequential ID, bumps the modification
    /// timestamp, and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1020
    /// Replaces a message's role and segments in place, optionally replacing
    /// its context and recording a new checkpoint for it.
    ///
    /// Returns `false` (without emitting events) when the ID doesn't exist.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1051
1052 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1053 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1054 return false;
1055 };
1056 self.messages.remove(index);
1057 self.touch_updated_at();
1058 cx.emit(ThreadEvent::MessageDeleted(id));
1059 true
1060 }
1061
1062 /// Returns the representation of this [`Thread`] in a textual form.
1063 ///
1064 /// This is the representation we use when attaching a thread as context to another thread.
1065 pub fn text(&self) -> String {
1066 let mut text = String::new();
1067
1068 for message in &self.messages {
1069 text.push_str(match message.role {
1070 language_model::Role::User => "User:",
1071 language_model::Role::Assistant => "Agent:",
1072 language_model::Role::System => "System:",
1073 });
1074 text.push('\n');
1075
1076 for segment in &message.segments {
1077 match segment {
1078 MessageSegment::Text(content) => text.push_str(content),
1079 MessageSegment::Thinking { text: content, .. } => {
1080 text.push_str(&format!("<think>{}</think>", content))
1081 }
1082 MessageSegment::RedactedThinking(_) => {}
1083 }
1084 }
1085 text.push('\n');
1086 }
1087
1088 text
1089 }
1090
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot before reading the thread state.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
            })
        })
    }
1175
    /// Returns how many agentic turns may still be sent to the model before
    /// `send_to_model` starts refusing to run (see `send_to_model`).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1179
    /// Resets the budget of agentic turns that `send_to_model` may consume.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1183
1184 pub fn send_to_model(
1185 &mut self,
1186 model: Arc<dyn LanguageModel>,
1187 window: Option<AnyWindowHandle>,
1188 cx: &mut Context<Self>,
1189 ) {
1190 if self.remaining_turns == 0 {
1191 return;
1192 }
1193
1194 self.remaining_turns -= 1;
1195
1196 let request = self.to_completion_request(model.clone(), cx);
1197
1198 self.stream_completion(request, model, window, cx);
1199 }
1200
1201 pub fn used_tools_since_last_user_message(&self) -> bool {
1202 for message in self.messages.iter().rev() {
1203 if self.tool_use.message_has_tool_results(message.id) {
1204 return true;
1205 } else if message.role == Role::User {
1206 return false;
1207 }
1208 }
1209
1210 false
1211 }
1212
    /// Builds the full `LanguageModelRequest` for this thread: system prompt,
    /// message history (including tool uses and their results), tool list,
    /// stale-file notes, and prompt-cache markers.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // The system prompt mentions which tools are available, so gather the
        // tool names before generating it.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires project context, which is prepared
        // asynchronously elsewhere; surface an error (but keep going) if it
        // isn't ready or the prompt template fails to render.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the last message that is safe to mark as a
        // prompt-cache anchor (i.e. every tool use up to it has a result).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool uses go on the assistant message; their results go on a
            // synthetic follow-up user message, per the provider protocol.
            // A tool use whose result is still pending makes this message
            // ineligible as a cache anchor.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Burn-mode (max mode) is only requested when the model supports it;
        // otherwise fall back to normal mode explicitly.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1370
1371 fn to_summarize_request(
1372 &self,
1373 model: &Arc<dyn LanguageModel>,
1374 added_user_message: String,
1375 cx: &App,
1376 ) -> LanguageModelRequest {
1377 let mut request = LanguageModelRequest {
1378 thread_id: None,
1379 prompt_id: None,
1380 mode: None,
1381 messages: vec![],
1382 tools: Vec::new(),
1383 tool_choice: None,
1384 stop: Vec::new(),
1385 temperature: AgentSettings::temperature_for_model(model, cx),
1386 };
1387
1388 for message in &self.messages {
1389 let mut request_message = LanguageModelRequestMessage {
1390 role: message.role,
1391 content: Vec::new(),
1392 cache: false,
1393 };
1394
1395 for segment in &message.segments {
1396 match segment {
1397 MessageSegment::Text(text) => request_message
1398 .content
1399 .push(MessageContent::Text(text.clone())),
1400 MessageSegment::Thinking { .. } => {}
1401 MessageSegment::RedactedThinking(_) => {}
1402 }
1403 }
1404
1405 if request_message.content.is_empty() {
1406 continue;
1407 }
1408
1409 request.messages.push(request_message);
1410 }
1411
1412 request.messages.push(LanguageModelRequestMessage {
1413 role: Role::User,
1414 content: vec![MessageContent::Text(added_user_message)],
1415 cache: false,
1416 });
1417
1418 request
1419 }
1420
1421 fn attached_tracked_files_state(
1422 &self,
1423 messages: &mut Vec<LanguageModelRequestMessage>,
1424 cx: &App,
1425 ) {
1426 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1427
1428 let mut stale_message = String::new();
1429
1430 let action_log = self.action_log.read(cx);
1431
1432 for stale_file in action_log.stale_buffers(cx) {
1433 let Some(file) = stale_file.read(cx).file() else {
1434 continue;
1435 };
1436
1437 if stale_message.is_empty() {
1438 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1439 }
1440
1441 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1442 }
1443
1444 let mut content = Vec::with_capacity(2);
1445
1446 if !stale_message.is_empty() {
1447 content.push(stale_message.into());
1448 }
1449
1450 if !content.is_empty() {
1451 let context_message = LanguageModelRequestMessage {
1452 role: Role::User,
1453 content,
1454 cache: false,
1455 };
1456
1457 messages.push(context_message);
1458 }
1459 }
1460
    /// Spawns a background task that streams completion events from `model`
    /// for `request` and applies each event to this thread as it arrives
    /// (text/thinking chunks, tool uses, usage updates, queue status).
    ///
    /// The task is registered in `pending_completions` so it can be canceled.
    /// When the stream ends it finalizes the pending checkpoint, dispatches
    /// on the stop reason (running tools, ending the turn, or unwinding a
    /// refusal), reports errors, emits `ThreadEvent::Stopped`, and records
    /// telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Request/response pairs are only captured when a callback has been
        // registered (e.g. by tests or diagnostics).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the telemetry at the end can
            // report this request's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message created for this response, if any.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Invalid tool-input JSON is recoverable (we feed the
                        // parse error back to the model); other errors abort
                        // the stream.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must attach to an assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool call is
                                // still streaming; emit its partial input.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so the UI stays responsive during
                    // long bursts of chunks.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Walk backwards collecting the refused
                                    // turn: everything up to and including the
                                    // user message that started it.
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to their dedicated UI
                            // treatments; fall back to a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1841
1842 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1843 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1844 println!("No thread summary model");
1845 return;
1846 };
1847
1848 if !model.provider.is_authenticated(cx) {
1849 return;
1850 }
1851
1852 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1853 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1854 If the conversation is about a specific subject, include it in the title. \
1855 Be descriptive. DO NOT speak in the first person.";
1856
1857 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1858
1859 self.summary = ThreadSummary::Generating;
1860
1861 self.pending_summary = cx.spawn(async move |this, cx| {
1862 let result = async {
1863 let mut messages = model.model.stream_completion(request, &cx).await?;
1864
1865 let mut new_summary = String::new();
1866 while let Some(event) = messages.next().await {
1867 let Ok(event) = event else {
1868 continue;
1869 };
1870 let text = match event {
1871 LanguageModelCompletionEvent::Text(text) => text,
1872 LanguageModelCompletionEvent::StatusUpdate(
1873 CompletionRequestStatus::UsageUpdated { amount, limit },
1874 ) => {
1875 this.update(cx, |thread, _cx| {
1876 thread.last_usage = Some(RequestUsage {
1877 limit,
1878 amount: amount as i32,
1879 });
1880 })?;
1881 continue;
1882 }
1883 _ => continue,
1884 };
1885
1886 let mut lines = text.lines();
1887 new_summary.extend(lines.next());
1888
1889 // Stop if the LLM generated multiple lines.
1890 if lines.next().is_some() {
1891 break;
1892 }
1893 }
1894
1895 anyhow::Ok(new_summary)
1896 }
1897 .await;
1898
1899 this.update(cx, |this, cx| {
1900 match result {
1901 Ok(new_summary) => {
1902 if new_summary.is_empty() {
1903 this.summary = ThreadSummary::Error;
1904 } else {
1905 this.summary = ThreadSummary::Ready(new_summary.into());
1906 }
1907 }
1908 Err(err) => {
1909 this.summary = ThreadSummary::Error;
1910 log::error!("Failed to generate thread summary: {}", err);
1911 }
1912 }
1913 cx.emit(ThreadEvent::SummaryGenerated);
1914 })
1915 .log_err()?;
1916
1917 Some(())
1918 });
1919 }
1920
    /// Starts generating a detailed Markdown summary of the thread in the
    /// background, unless one is already generated (or generating) for the
    /// current last message.
    ///
    /// The result is broadcast through the `detailed_summary_tx`/`rx` watch
    /// channel, and the thread is saved afterwards so the summary can be
    /// reused across sessions.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request itself fails, revert to NotGenerated so callers
            // can fall back to the raw thread text.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2010
2011 pub async fn wait_for_detailed_summary_or_text(
2012 this: &Entity<Self>,
2013 cx: &mut AsyncApp,
2014 ) -> Option<SharedString> {
2015 let mut detailed_summary_rx = this
2016 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2017 .ok()?;
2018 loop {
2019 match detailed_summary_rx.recv().await? {
2020 DetailedSummaryState::Generating { .. } => {}
2021 DetailedSummaryState::NotGenerated => {
2022 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2023 }
2024 DetailedSummaryState::Generated { text, .. } => return Some(text),
2025 }
2026 }
2027 }
2028
2029 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2030 self.detailed_summary_rx
2031 .borrow()
2032 .text()
2033 .unwrap_or_else(|| self.text().into())
2034 }
2035
    /// Whether a detailed summary is currently being generated in the background.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2042
2043 pub fn use_pending_tools(
2044 &mut self,
2045 window: Option<AnyWindowHandle>,
2046 cx: &mut Context<Self>,
2047 model: Arc<dyn LanguageModel>,
2048 ) -> Vec<PendingToolUse> {
2049 self.auto_capture_telemetry(cx);
2050 let request = Arc::new(self.to_completion_request(model.clone(), cx));
2051 let pending_tool_uses = self
2052 .tool_use
2053 .pending_tool_uses()
2054 .into_iter()
2055 .filter(|tool_use| tool_use.status.is_idle())
2056 .cloned()
2057 .collect::<Vec<_>>();
2058
2059 for tool_use in pending_tool_uses.iter() {
2060 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2061 if tool.needs_confirmation(&tool_use.input, cx)
2062 && !AgentSettings::get_global(cx).always_allow_tool_actions
2063 {
2064 self.tool_use.confirm_tool_use(
2065 tool_use.id.clone(),
2066 tool_use.ui_text.clone(),
2067 tool_use.input.clone(),
2068 request.clone(),
2069 tool,
2070 );
2071 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2072 } else {
2073 self.run_tool(
2074 tool_use.id.clone(),
2075 tool_use.ui_text.clone(),
2076 tool_use.input.clone(),
2077 request.clone(),
2078 tool,
2079 model.clone(),
2080 window,
2081 cx,
2082 );
2083 }
2084 } else {
2085 self.handle_hallucinated_tool_use(
2086 tool_use.id.clone(),
2087 tool_use.name.clone(),
2088 window,
2089 cx,
2090 );
2091 }
2092 }
2093
2094 pending_tool_uses
2095 }
2096
2097 pub fn handle_hallucinated_tool_use(
2098 &mut self,
2099 tool_use_id: LanguageModelToolUseId,
2100 hallucinated_tool_name: Arc<str>,
2101 window: Option<AnyWindowHandle>,
2102 cx: &mut Context<Thread>,
2103 ) {
2104 let available_tools = self.tools.read(cx).enabled_tools(cx);
2105
2106 let tool_list = available_tools
2107 .iter()
2108 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2109 .collect::<Vec<_>>()
2110 .join("\n");
2111
2112 let error_message = format!(
2113 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2114 hallucinated_tool_name, tool_list
2115 );
2116
2117 let pending_tool_use = self.tool_use.insert_tool_output(
2118 tool_use_id.clone(),
2119 hallucinated_tool_name,
2120 Err(anyhow!("Missing tool call: {error_message}")),
2121 self.configured_model.as_ref(),
2122 );
2123
2124 cx.emit(ThreadEvent::MissingToolUse {
2125 tool_use_id: tool_use_id.clone(),
2126 ui_text: error_message.into(),
2127 });
2128
2129 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2130 }
2131
2132 pub fn receive_invalid_tool_json(
2133 &mut self,
2134 tool_use_id: LanguageModelToolUseId,
2135 tool_name: Arc<str>,
2136 invalid_json: Arc<str>,
2137 error: String,
2138 window: Option<AnyWindowHandle>,
2139 cx: &mut Context<Thread>,
2140 ) {
2141 log::error!("The model returned invalid input JSON: {invalid_json}");
2142
2143 let pending_tool_use = self.tool_use.insert_tool_output(
2144 tool_use_id.clone(),
2145 tool_name,
2146 Err(anyhow!("Error parsing input JSON: {error}")),
2147 self.configured_model.as_ref(),
2148 );
2149 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2150 pending_tool_use.ui_text.clone()
2151 } else {
2152 log::error!(
2153 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2154 );
2155 format!("Unknown tool {}", tool_use_id).into()
2156 };
2157
2158 cx.emit(ThreadEvent::InvalidToolInput {
2159 tool_use_id: tool_use_id.clone(),
2160 ui_text,
2161 invalid_input_json: invalid_json,
2162 });
2163
2164 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2165 }
2166
    /// Starts executing a tool use: spawns the tool's async work and marks
    /// the pending tool use as running, labeled with `ui_text`.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2183
2184 fn spawn_tool_use(
2185 &mut self,
2186 tool_use_id: LanguageModelToolUseId,
2187 request: Arc<LanguageModelRequest>,
2188 input: serde_json::Value,
2189 tool: Arc<dyn Tool>,
2190 model: Arc<dyn LanguageModel>,
2191 window: Option<AnyWindowHandle>,
2192 cx: &mut Context<Thread>,
2193 ) -> Task<()> {
2194 let tool_name: Arc<str> = tool.name().into();
2195
2196 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2197 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2198 } else {
2199 tool.run(
2200 input,
2201 request,
2202 self.project.clone(),
2203 self.action_log.clone(),
2204 model,
2205 window,
2206 cx,
2207 )
2208 };
2209
2210 // Store the card separately if it exists
2211 if let Some(card) = tool_result.card.clone() {
2212 self.tool_use
2213 .insert_tool_result_card(tool_use_id.clone(), card);
2214 }
2215
2216 cx.spawn({
2217 async move |thread: WeakEntity<Thread>, cx| {
2218 let output = tool_result.output.await;
2219
2220 thread
2221 .update(cx, |thread, cx| {
2222 let pending_tool_use = thread.tool_use.insert_tool_output(
2223 tool_use_id.clone(),
2224 tool_name,
2225 output,
2226 thread.configured_model.as_ref(),
2227 );
2228 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2229 })
2230 .ok();
2231 }
2232 })
2233 }
2234
    /// Called whenever a tool use resolves (successfully, with an error, or
    /// canceled). Once every pending tool has finished, automatically sends
    /// the accumulated results back to the configured model — unless this
    /// resolution came from a cancellation.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    self.send_to_model(model.clone(), window, cx);
                }
                // Telemetry is captured even when canceled.
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2257
2258 /// Cancels the last pending completion, if there are any pending.
2259 ///
2260 /// Returns whether a completion was canceled.
2261 pub fn cancel_last_completion(
2262 &mut self,
2263 window: Option<AnyWindowHandle>,
2264 cx: &mut Context<Self>,
2265 ) -> bool {
2266 let mut canceled = self.pending_completions.pop().is_some();
2267
2268 for pending_tool_use in self.tool_use.cancel_pending() {
2269 canceled = true;
2270 self.tool_finished(
2271 pending_tool_use.id.clone(),
2272 Some(pending_tool_use),
2273 true,
2274 window,
2275 cx,
2276 );
2277 }
2278
2279 if canceled {
2280 cx.emit(ThreadEvent::CompletionCanceled);
2281
2282 // When canceled, we always want to insert the checkpoint.
2283 // (We skip over finalize_pending_checkpoint, because it
2284 // would conclude we didn't have anything to insert here.)
2285 if let Some(checkpoint) = self.pending_checkpoint.take() {
2286 self.insert_checkpoint(checkpoint, cx);
2287 }
2288 } else {
2289 self.finalize_pending_checkpoint(cx);
2290 }
2291
2292 canceled
2293 }
2294
2295 /// Signals that any in-progress editing should be canceled.
2296 ///
2297 /// This method is used to notify listeners (like ActiveThread) that
2298 /// they should cancel any editing operations.
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. Emits
    /// [`ThreadEvent::CancelEditing`]; it does not mutate thread state itself.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2302
    /// Returns the thread-level feedback, if any has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2306
2307 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2308 self.message_feedback.get(&message_id).copied()
2309 }
2310
    /// Records the user's feedback for a single message and reports it, along
    /// with a serialized snapshot of the thread, via telemetry.
    ///
    /// Returns an immediately-ready `Ok` task when the same feedback is
    /// already recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything we need on the main thread before moving the
        // reporting work to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2368
    /// Records thread-level feedback and reports it via telemetry.
    ///
    /// If the thread contains an assistant message, the feedback is attached
    /// to the most recent one via [`Self::report_message_feedback`]; otherwise
    /// it is stored on the thread itself and reported without message fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather
                // than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2414
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree; they all run
        // concurrently and are joined below.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty (unsaved) buffers. If the app
            // context is gone, we simply record no unsaved buffers.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2452
    /// Captures the absolute path and git state of one worktree for a
    /// [`ProjectSnapshot`]. Any failure along the way degrades to an empty
    /// path and/or `git_state: None` rather than an error.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, fall back to an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working copy contains this worktree,
            // if there is one, then query its backend on the repo job queue.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // ask for remote/sha/diff information.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers produced above; every
            // failure collapses to "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2530
2531 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2532 let mut markdown = Vec::new();
2533
2534 let summary = self.summary().or_default();
2535 writeln!(markdown, "# {summary}\n")?;
2536
2537 for message in self.messages() {
2538 writeln!(
2539 markdown,
2540 "## {role}\n",
2541 role = match message.role {
2542 Role::User => "User",
2543 Role::Assistant => "Agent",
2544 Role::System => "System",
2545 }
2546 )?;
2547
2548 if !message.loaded_context.text.is_empty() {
2549 writeln!(markdown, "{}", message.loaded_context.text)?;
2550 }
2551
2552 if !message.loaded_context.images.is_empty() {
2553 writeln!(
2554 markdown,
2555 "\n{} images attached as context.\n",
2556 message.loaded_context.images.len()
2557 )?;
2558 }
2559
2560 for segment in &message.segments {
2561 match segment {
2562 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2563 MessageSegment::Thinking { text, .. } => {
2564 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2565 }
2566 MessageSegment::RedactedThinking(_) => {}
2567 }
2568 }
2569
2570 for tool_use in self.tool_uses_for_message(message.id, cx) {
2571 writeln!(
2572 markdown,
2573 "**Use Tool: {} ({})**",
2574 tool_use.name, tool_use.id
2575 )?;
2576 writeln!(markdown, "```json")?;
2577 writeln!(
2578 markdown,
2579 "{}",
2580 serde_json::to_string_pretty(&tool_use.input)?
2581 )?;
2582 writeln!(markdown, "```")?;
2583 }
2584
2585 for tool_result in self.tool_results_for_message(message.id) {
2586 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2587 if tool_result.is_error {
2588 write!(markdown, " (Error)")?;
2589 }
2590
2591 writeln!(markdown, "**\n")?;
2592 match &tool_result.content {
2593 LanguageModelToolResultContent::Text(text) => {
2594 writeln!(markdown, "{text}")?;
2595 }
2596 LanguageModelToolResultContent::Image(image) => {
2597 writeln!(markdown, "", image.source)?;
2598 }
2599 }
2600
2601 if let Some(output) = tool_result.output.as_ref() {
2602 writeln!(
2603 markdown,
2604 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2605 serde_json::to_string_pretty(output)?
2606 )?;
2607 }
2608 }
2609 }
2610
2611 Ok(String::from_utf8_lossy(&markdown).to_string())
2612 }
2613
2614 pub fn keep_edits_in_range(
2615 &mut self,
2616 buffer: Entity<language::Buffer>,
2617 buffer_range: Range<language::Anchor>,
2618 cx: &mut Context<Self>,
2619 ) {
2620 self.action_log.update(cx, |action_log, cx| {
2621 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2622 });
2623 }
2624
2625 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2626 self.action_log
2627 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2628 }
2629
2630 pub fn reject_edits_in_ranges(
2631 &mut self,
2632 buffer: Entity<language::Buffer>,
2633 buffer_ranges: Vec<Range<language::Anchor>>,
2634 cx: &mut Context<Self>,
2635 ) -> Task<Result<()>> {
2636 self.action_log.update(cx, |action_log, cx| {
2637 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2638 })
2639 }
2640
    /// The [`ActionLog`] tracking the agent's edits in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2644
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2648
2649 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2650 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2651 return;
2652 }
2653
2654 let now = Instant::now();
2655 if let Some(last) = self.last_auto_capture_at {
2656 if now.duration_since(last).as_secs() < 10 {
2657 return;
2658 }
2659 }
2660
2661 self.last_auto_capture_at = Some(now);
2662
2663 let thread_id = self.id().clone();
2664 let github_login = self
2665 .project
2666 .read(cx)
2667 .user_store()
2668 .read(cx)
2669 .current_user()
2670 .map(|user| user.github_login.clone());
2671 let client = self.project.read(cx).client();
2672 let serialize_task = self.serialize(cx);
2673
2674 cx.background_executor()
2675 .spawn(async move {
2676 if let Ok(serialized_thread) = serialize_task.await {
2677 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2678 telemetry::event!(
2679 "Agent Thread Auto-Captured",
2680 thread_id = thread_id.to_string(),
2681 thread_data = thread_data,
2682 auto_capture_reason = "tracked_user",
2683 github_login = github_login
2684 );
2685
2686 client.telemetry().flush_events().await;
2687 }
2688 }
2689 })
2690 .detach();
2691 }
2692
    /// Returns the thread's cumulative token usage.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2696
2697 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2698 let Some(model) = self.configured_model.as_ref() else {
2699 return TotalTokenUsage::default();
2700 };
2701
2702 let max = model.model.max_token_count();
2703
2704 let index = self
2705 .messages
2706 .iter()
2707 .position(|msg| msg.id == message_id)
2708 .unwrap_or(0);
2709
2710 if index == 0 {
2711 return TotalTokenUsage { total: 0, max };
2712 }
2713
2714 let token_usage = &self
2715 .request_token_usage
2716 .get(index - 1)
2717 .cloned()
2718 .unwrap_or_default();
2719
2720 TotalTokenUsage {
2721 total: token_usage.total_tokens() as usize,
2722 max,
2723 }
2724 }
2725
2726 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2727 let model = self.configured_model.as_ref()?;
2728
2729 let max = model.model.max_token_count();
2730
2731 if let Some(exceeded_error) = &self.exceeded_window_error {
2732 if model.model.id() == exceeded_error.model_id {
2733 return Some(TotalTokenUsage {
2734 total: exceeded_error.token_count,
2735 max,
2736 });
2737 }
2738 }
2739
2740 let total = self
2741 .token_usage_at_last_message()
2742 .unwrap_or_default()
2743 .total_tokens() as usize;
2744
2745 Some(TotalTokenUsage { total, max })
2746 }
2747
2748 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2749 self.request_token_usage
2750 .get(self.messages.len().saturating_sub(1))
2751 .or_else(|| self.request_token_usage.last())
2752 .cloned()
2753 }
2754
2755 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2756 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2757 self.request_token_usage
2758 .resize(self.messages.len(), placeholder);
2759
2760 if let Some(last) = self.request_token_usage.last_mut() {
2761 *last = token_usage;
2762 }
2763 }
2764
2765 pub fn deny_tool_use(
2766 &mut self,
2767 tool_use_id: LanguageModelToolUseId,
2768 tool_name: Arc<str>,
2769 window: Option<AnyWindowHandle>,
2770 cx: &mut Context<Self>,
2771 ) {
2772 let err = Err(anyhow::anyhow!(
2773 "Permission to run tool action denied by user"
2774 ));
2775
2776 self.tool_use.insert_tool_output(
2777 tool_use_id.clone(),
2778 tool_name,
2779 err,
2780 self.configured_model.as_ref(),
2781 );
2782 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2783 }
2784}
2785
/// Errors surfaced to the user via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before more completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The plan's model request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A general-purpose error rendered with a header and message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2798
/// Events emitted by a `Thread` (see the `EventEmitter<ThreadEvent>` impl)
/// so listeners can react to streaming, tool, and lifecycle changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error to surface to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use streamed in from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be used as-is.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        /// The raw input JSON that failed validation.
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Emitted by `tool_finished` whenever a tool use completes, is denied,
    /// or is canceled.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    /// Emitted by `cancel_editing`; listeners should abort in-progress edits.
    CancelEditing,
    /// Emitted by `cancel_last_completion` when something was actually canceled.
    CompletionCanceled,
}
2841
// Allows `Thread` to emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2843
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    queue_state: QueueState,
    // Owning handle for the spawned completion work; kept so the task keeps
    // running for as long as this entry is alive.
    _task: Task<()>,
}
2849
2850#[cfg(test)]
2851mod tests {
2852 use super::*;
2853 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2854 use agent_settings::{AgentSettings, LanguageModelParameters};
2855 use assistant_tool::ToolRegistry;
2856 use editor::EditorSettings;
2857 use gpui::TestAppContext;
2858 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2859 use project::{FakeFs, Project};
2860 use prompt_store::PromptBuilder;
2861 use serde_json::json;
2862 use settings::{Settings, SettingsStore};
2863 use std::sync::Arc;
2864 use theme::ThemeSettings;
2865 use util::path;
2866 use workspace::Workspace;
2867
    // Verifies that context attached to a user message is rendered into the
    // message text and carried into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; messages[1] is the user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2944
    // Verifies that each message only picks up context that wasn't already
    // attached to an earlier message, and that editing from an earlier
    // message re-exposes contexts attached at or after that point.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // From message 2 onward, contexts attached at or after that message
        // count as "new" again (file1 stays excluded).
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3100
    // Verifies that messages inserted with no attached context have empty
    // context text and pass through to the request unchanged.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3178
    // Verifies that when a buffer attached as context is modified after being
    // read, the next completion request ends with a stale-file notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3266
    // Verifies that `model_parameters` temperature overrides match on
    // provider, model, both, or neither — and that a provider mismatch means
    // no temperature is applied.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3360
3361 #[gpui::test]
3362 async fn test_thread_summary(cx: &mut TestAppContext) {
3363 init_test_settings(cx);
3364
3365 let project = create_test_project(cx, json!({})).await;
3366
3367 let (_, _thread_store, thread, _context_store, model) =
3368 setup_test_environment(cx, project.clone()).await;
3369
3370 // Initial state should be pending
3371 thread.read_with(cx, |thread, _| {
3372 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3373 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3374 });
3375
3376 // Manually setting the summary should not be allowed in this state
3377 thread.update(cx, |thread, cx| {
3378 thread.set_summary("This should not work", cx);
3379 });
3380
3381 thread.read_with(cx, |thread, _| {
3382 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3383 });
3384
3385 // Send a message
3386 thread.update(cx, |thread, cx| {
3387 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3388 thread.send_to_model(model.clone(), None, cx);
3389 });
3390
3391 let fake_model = model.as_fake();
3392 simulate_successful_response(&fake_model, cx);
3393
3394 // Should start generating summary when there are >= 2 messages
3395 thread.read_with(cx, |thread, _| {
3396 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3397 });
3398
3399 // Should not be able to set the summary while generating
3400 thread.update(cx, |thread, cx| {
3401 thread.set_summary("This should not work either", cx);
3402 });
3403
3404 thread.read_with(cx, |thread, _| {
3405 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3406 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3407 });
3408
3409 cx.run_until_parked();
3410 fake_model.stream_last_completion_response("Brief");
3411 fake_model.stream_last_completion_response(" Introduction");
3412 fake_model.end_last_completion_stream();
3413 cx.run_until_parked();
3414
3415 // Summary should be set
3416 thread.read_with(cx, |thread, _| {
3417 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3418 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3419 });
3420
3421 // Now we should be able to set a summary
3422 thread.update(cx, |thread, cx| {
3423 thread.set_summary("Brief Intro", cx);
3424 });
3425
3426 thread.read_with(cx, |thread, _| {
3427 assert_eq!(thread.summary().or_default(), "Brief Intro");
3428 });
3429
3430 // Test setting an empty summary (should default to DEFAULT)
3431 thread.update(cx, |thread, cx| {
3432 thread.set_summary("", cx);
3433 });
3434
3435 thread.read_with(cx, |thread, _| {
3436 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3437 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3438 });
3439 }
3440
3441 #[gpui::test]
3442 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3443 init_test_settings(cx);
3444
3445 let project = create_test_project(cx, json!({})).await;
3446
3447 let (_, _thread_store, thread, _context_store, model) =
3448 setup_test_environment(cx, project.clone()).await;
3449
3450 test_summarize_error(&model, &thread, cx);
3451
3452 // Now we should be able to set a summary
3453 thread.update(cx, |thread, cx| {
3454 thread.set_summary("Brief Intro", cx);
3455 });
3456
3457 thread.read_with(cx, |thread, _| {
3458 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3459 assert_eq!(thread.summary().or_default(), "Brief Intro");
3460 });
3461 }
3462
3463 #[gpui::test]
3464 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3465 init_test_settings(cx);
3466
3467 let project = create_test_project(cx, json!({})).await;
3468
3469 let (_, _thread_store, thread, _context_store, model) =
3470 setup_test_environment(cx, project.clone()).await;
3471
3472 test_summarize_error(&model, &thread, cx);
3473
3474 // Sending another message should not trigger another summarize request
3475 thread.update(cx, |thread, cx| {
3476 thread.insert_user_message(
3477 "How are you?",
3478 ContextLoadResult::default(),
3479 None,
3480 vec![],
3481 cx,
3482 );
3483 thread.send_to_model(model.clone(), None, cx);
3484 });
3485
3486 let fake_model = model.as_fake();
3487 simulate_successful_response(&fake_model, cx);
3488
3489 thread.read_with(cx, |thread, _| {
3490 // State is still Error, not Generating
3491 assert!(matches!(thread.summary(), ThreadSummary::Error));
3492 });
3493
3494 // But the summarize request can be invoked manually
3495 thread.update(cx, |thread, cx| {
3496 thread.summarize(cx);
3497 });
3498
3499 thread.read_with(cx, |thread, _| {
3500 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3501 });
3502
3503 cx.run_until_parked();
3504 fake_model.stream_last_completion_response("A successful summary");
3505 fake_model.end_last_completion_stream();
3506 cx.run_until_parked();
3507
3508 thread.read_with(cx, |thread, _| {
3509 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3510 assert_eq!(thread.summary().or_default(), "A successful summary");
3511 });
3512 }
3513
3514 fn test_summarize_error(
3515 model: &Arc<dyn LanguageModel>,
3516 thread: &Entity<Thread>,
3517 cx: &mut TestAppContext,
3518 ) {
3519 thread.update(cx, |thread, cx| {
3520 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3521 thread.send_to_model(model.clone(), None, cx);
3522 });
3523
3524 let fake_model = model.as_fake();
3525 simulate_successful_response(&fake_model, cx);
3526
3527 thread.read_with(cx, |thread, _| {
3528 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3529 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3530 });
3531
3532 // Simulate summary request ending
3533 cx.run_until_parked();
3534 fake_model.end_last_completion_stream();
3535 cx.run_until_parked();
3536
3537 // State is set to Error and default message
3538 thread.read_with(cx, |thread, _| {
3539 assert!(matches!(thread.summary(), ThreadSummary::Error));
3540 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3541 });
3542 }
3543
3544 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3545 cx.run_until_parked();
3546 fake_model.stream_last_completion_response("Assistant response");
3547 fake_model.end_last_completion_stream();
3548 cx.run_until_parked();
3549 }
3550
    /// Installs a test `SettingsStore` as a global and runs the settings /
    /// registry initialization for every crate these thread tests touch.
    /// The store is set first; the subsequent `init`/`register` calls
    /// presumably read it — keep it at the top.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3567
3568 // Helper to create a test project with test files
3569 async fn create_test_project(
3570 cx: &mut TestAppContext,
3571 files: serde_json::Value,
3572 ) -> Entity<Project> {
3573 let fs = FakeFs::new(cx.executor());
3574 fs.insert_tree(path!("/test"), files).await;
3575 Project::test(fs, [path!("/test").as_ref()], cx).await
3576 }
3577
3578 async fn setup_test_environment(
3579 cx: &mut TestAppContext,
3580 project: Entity<Project>,
3581 ) -> (
3582 Entity<Workspace>,
3583 Entity<ThreadStore>,
3584 Entity<Thread>,
3585 Entity<ContextStore>,
3586 Arc<dyn LanguageModel>,
3587 ) {
3588 let (workspace, cx) =
3589 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3590
3591 let thread_store = cx
3592 .update(|_, cx| {
3593 ThreadStore::load(
3594 project.clone(),
3595 cx.new(|_| ToolWorkingSet::default()),
3596 None,
3597 Arc::new(PromptBuilder::new(None).unwrap()),
3598 cx,
3599 )
3600 })
3601 .await
3602 .unwrap();
3603
3604 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3605 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3606
3607 let provider = Arc::new(FakeLanguageModelProvider);
3608 let model = provider.test_model();
3609 let model: Arc<dyn LanguageModel> = Arc::new(model);
3610
3611 cx.update(|_, cx| {
3612 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3613 registry.set_default_model(
3614 Some(ConfiguredModel {
3615 provider: provider.clone(),
3616 model: model.clone(),
3617 }),
3618 cx,
3619 );
3620 registry.set_thread_summary_model(
3621 Some(ConfiguredModel {
3622 provider,
3623 model: model.clone(),
3624 }),
3625 cx,
3626 );
3627 })
3628 });
3629
3630 (workspace, thread_store, thread, context_store, model)
3631 }
3632
3633 async fn add_file_to_context(
3634 project: &Entity<Project>,
3635 context_store: &Entity<ContextStore>,
3636 path: &str,
3637 cx: &mut TestAppContext,
3638 ) -> Result<Entity<language::Buffer>> {
3639 let buffer_path = project
3640 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3641 .unwrap();
3642
3643 let buffer = project
3644 .update(cx, |project, cx| {
3645 project.open_buffer(buffer_path.clone(), cx)
3646 })
3647 .await
3648 .unwrap();
3649
3650 context_store.update(cx, |context_store, cx| {
3651 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3652 });
3653
3654 Ok(buffer)
3655 }
3656}