1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Span within the message text that the crease covers.
    pub range: Range<usize>,
    // Display metadata (icon path and label) used to render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    // Monotonically increasing ID, unique within the thread.
    pub id: MessageId,
    // Whether the user, assistant, or system produced this message.
    pub role: Role,
    // Ordered content chunks (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    // Context attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    // Creases for restoring collapsed context when re-editing the message
    // (see [`MessageCrease`]).
    pub creases: Vec<MessageCrease>,
    // Hidden messages, e.g. the auto-inserted "continue" prompt
    // (see `Thread::insert_invisible_continue_message`).
    pub is_hidden: bool,
}
120
121impl Message {
122 /// Returns whether the message contains any meaningful text that should be displayed
123 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
124 pub fn should_display_content(&self) -> bool {
125 self.segments.iter().all(|segment| segment.should_display())
126 }
127
128 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
129 if let Some(MessageSegment::Thinking {
130 text: segment,
131 signature: current_signature,
132 }) = self.segments.last_mut()
133 {
134 if let Some(signature) = signature {
135 *current_signature = Some(signature);
136 }
137 segment.push_str(text);
138 } else {
139 self.segments.push(MessageSegment::Thinking {
140 text: text.to_string(),
141 signature,
142 });
143 }
144 }
145
146 pub fn push_text(&mut self, text: &str) {
147 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Text(text.to_string()));
151 }
152 }
153
154 pub fn to_string(&self) -> String {
155 let mut result = String::new();
156
157 if !self.loaded_context.text.is_empty() {
158 result.push_str(&self.loaded_context.text);
159 }
160
161 for segment in &self.segments {
162 match segment {
163 MessageSegment::Text(text) => result.push_str(text),
164 MessageSegment::Thinking { text, .. } => {
165 result.push_str("<think>\n");
166 result.push_str(text);
167 result.push_str("\n</think>");
168 }
169 MessageSegment::RedactedThinking(_) => {}
170 }
171 }
172
173 result
174 }
175}
176
/// A chunk of content within a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary response text.
    Text(String),
    /// Chain-of-thought text, optionally carrying the provider's signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque provider-encrypted thinking data; never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has content worth displaying.
    ///
    /// Fix: this previously returned `true` for *empty* text segments
    /// (`text.is_empty()`), which inverted the contract documented on
    /// [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking carries nothing human-readable.
            Self::RedactedThinking(_) => false,
        }
    }
}
196
/// Point-in-time description of the project, captured when a thread is created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    // One entry per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    // When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
203
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    // Path of the worktree root.
    pub worktree_path: String,
    // Git metadata; `None` presumably means no repository was found — TODO confirm.
    pub git_state: Option<GitState>,
}
209
/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    // URL of the remote, if one is configured.
    pub remote_url: Option<String>,
    // SHA of the HEAD commit, if available.
    pub head_sha: Option<String>,
    // Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    // Textual diff of uncommitted changes, if captured.
    pub diff: Option<String>,
}
217
/// Pairs a git checkpoint with the message it was taken for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    // Message this checkpoint belongs to.
    message_id: MessageId,
    // Underlying repository state captured by the project's git store.
    git_checkpoint: GitStoreCheckpoint,
}
223
/// User feedback on a thread or on an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
229
230pub enum LastRestoreCheckpoint {
231 Pending {
232 message_id: MessageId,
233 },
234 Error {
235 message_id: MessageId,
236 error: String,
237 },
238}
239
240impl LastRestoreCheckpoint {
241 pub fn message_id(&self) -> MessageId {
242 match self {
243 LastRestoreCheckpoint::Pending { message_id } => *message_id,
244 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
245 }
246 }
247}
248
249#[derive(Clone, Debug, Default, Serialize, Deserialize)]
250pub enum DetailedSummaryState {
251 #[default]
252 NotGenerated,
253 Generating {
254 message_id: MessageId,
255 },
256 Generated {
257 text: SharedString,
258 message_id: MessageId,
259 },
260}
261
262impl DetailedSummaryState {
263 fn text(&self) -> Option<SharedString> {
264 if let Self::Generated { text, .. } = self {
265 Some(text.clone())
266 } else {
267 None
268 }
269 }
270}
271
/// Running token total for a thread together with the model's context limit.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens allowed; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD env var. Fix: a malformed value now
        // falls back to the default instead of panicking via `unwrap()`
        // (also avoids the eager `"0.8".to_string()` allocation).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Qualitative classification of token usage relative to the context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
316
/// Progress of a pending completion through the request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    // The request is being sent.
    Sending,
    // The request is waiting at the given queue position.
    Queued { position: usize },
    // The request has started streaming.
    Started,
}
323
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Bumped by `touch_updated_at` whenever the thread changes.
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting the detailed-summary lifecycle.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept in ascending `MessageId` order; IDs are assigned monotonically
    // by `insert_message`, so `Thread::message` can binary-search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    // Regenerated by `advance_prompt_id` for each user submission.
    last_prompt_id: PromptId,
    // Shared context used to build the system prompt.
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    // Tracks pending and completed tool uses and their results.
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // Status of the most recent checkpoint restore, if one happened.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint awaiting finalization once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Token usage recorded per request.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    // Set when a request overflowed the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    // Thread-level and per-message user feedback.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Updated by `received_chunk` as streamed events arrive;
    // consulted by `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining agent turns before `send_to_model` stops sending.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
364
365#[derive(Clone, Debug, PartialEq, Eq)]
366pub enum ThreadSummary {
367 Pending,
368 Generating,
369 Ready(SharedString),
370 Error,
371}
372
373impl ThreadSummary {
374 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
375
376 pub fn or_default(&self) -> SharedString {
377 self.unwrap_or(Self::DEFAULT)
378 }
379
380 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
381 self.ready().unwrap_or_else(|| message.into())
382 }
383
384 pub fn ready(&self) -> Option<SharedString> {
385 match self {
386 ThreadSummary::Ready(summary) => Some(summary.clone()),
387 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
388 }
389 }
390}
391
/// Details of the most recent request that exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
399
400impl Thread {
    /// Creates an empty thread, using the globally configured default model
    /// and the user's preferred completion mode.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's state as of thread creation; the task
                // is shared so multiple consumers can await the same snapshot.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
454
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model selection is re-resolved against the registry,
    /// falling back to the default model; a missing completion mode falls
    /// back to the user's settings.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the rendered context text is persisted; context
                    // handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
575
    /// Installs a callback invoked with each completion request and the
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// The configured model, initializing it from the global default when unset.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Overrides the model used by this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
623
624 pub fn summary(&self) -> &ThreadSummary {
625 &self.summary
626 }
627
628 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
629 let current_summary = match &self.summary {
630 ThreadSummary::Pending | ThreadSummary::Generating => return,
631 ThreadSummary::Ready(summary) => summary,
632 ThreadSummary::Error => &ThreadSummary::DEFAULT,
633 };
634
635 let mut new_summary = new_summary.into();
636
637 if new_summary.is_empty() {
638 new_summary = ThreadSummary::DEFAULT;
639 }
640
641 if current_summary != &new_summary {
642 self.summary = ThreadSummary::Ready(new_summary);
643 cx.emit(ThreadEvent::SummaryChanged);
644 }
645 }
646
647 pub fn completion_mode(&self) -> CompletionMode {
648 self.completion_mode
649 }
650
651 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
652 self.completion_mode = mode;
653 }
654
655 pub fn message(&self, id: MessageId) -> Option<&Message> {
656 let index = self
657 .messages
658 .binary_search_by(|message| message.id.cmp(&id))
659 .ok()?;
660
661 self.messages.get(index)
662 }
663
664 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
665 self.messages.iter()
666 }
667
    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    // NOTE(review): the implementation returns `Some` whenever a chunk
    // timestamp exists, regardless of `is_generating()`; presumably the
    // timestamp is cleared elsewhere when generation ends — confirm.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Consider the stream stale after 250ms without a chunk.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (see `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue position of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
690
691 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
692 &self.tools
693 }
694
695 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
696 self.tool_use
697 .pending_tool_uses()
698 .into_iter()
699 .find(|tool_use| &tool_use.id == id)
700 }
701
702 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
703 self.tool_use
704 .pending_tool_uses()
705 .into_iter()
706 .filter(|tool_use| tool_use.status.needs_confirmation())
707 }
708
709 pub fn has_pending_tool_uses(&self) -> bool {
710 !self.tool_use.pending_tool_uses().is_empty()
711 }
712
713 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
714 self.checkpoints_by_message.get(&id).cloned()
715 }
716
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are exposed through `last_restore_checkpoint`
    /// and `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Only drop the restored-over messages once git succeeded.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
752
753 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
754 let pending_checkpoint = if self.is_generating() {
755 return;
756 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
757 checkpoint
758 } else {
759 return;
760 };
761
762 self.finalize_checkpoint(pending_checkpoint, cx);
763 }
764
    /// Takes a fresh git checkpoint and keeps `pending_checkpoint` only when
    /// the repository changed since it was recorded (or when the comparison
    /// could not be made).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "changed" so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
810
811 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
812 let Some(message_ix) = self
813 .messages
814 .iter()
815 .rposition(|message| message.id == message_id)
816 else {
817 return;
818 };
819 for deleted_message in self.messages.drain(message_ix..) {
820 self.checkpoints_by_message.remove(&deleted_message.id);
821 }
822 cx.notify();
823 }
824
825 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
826 self.messages
827 .iter()
828 .find(|message| message.id == id)
829 .into_iter()
830 .flat_map(|message| message.loaded_context.contexts.iter())
831 }
832
833 pub fn is_turn_end(&self, ix: usize) -> bool {
834 if self.messages.is_empty() {
835 return false;
836 }
837
838 if !self.is_generating() && ix == self.messages.len() - 1 {
839 return true;
840 }
841
842 let Some(message) = self.messages.get(ix) else {
843 return false;
844 };
845
846 if message.role != Role::Assistant {
847 return false;
848 }
849
850 self.messages
851 .get(ix + 1)
852 .and_then(|message| {
853 self.message(message.id)
854 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
855 })
856 .unwrap_or(false)
857 }
858
    /// Usage reported for the most recent request, if known.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
876
877 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
878 self.tool_use.tool_uses_for_message(id, cx)
879 }
880
881 pub fn tool_results_for_message(
882 &self,
883 assistant_message_id: MessageId,
884 ) -> Vec<&LanguageModelToolResult> {
885 self.tool_use.tool_results_for_message(assistant_message_id)
886 }
887
888 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
889 self.tool_use.tool_result(id)
890 }
891
892 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
893 match &self.tool_use.tool_result(id)?.content {
894 LanguageModelToolResultContent::Text(text) => Some(text),
895 LanguageModelToolResultContent::Image(_) => {
896 // TODO: We should display image
897 None
898 }
899 }
900 }
901
902 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
903 self.tool_use.tool_result_card(id).cloned()
904 }
905
906 /// Return tools that are both enabled and supported by the model
907 pub fn available_tools(
908 &self,
909 cx: &App,
910 model: Arc<dyn LanguageModel>,
911 ) -> Vec<LanguageModelRequestTool> {
912 if model.supports_tools() {
913 self.tools()
914 .read(cx)
915 .enabled_tools(cx)
916 .into_iter()
917 .filter_map(|tool| {
918 // Skip tools that cannot be supported
919 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
920 Some(LanguageModelRequestTool {
921 name: tool.name(),
922 description: tool.description(),
923 input_schema,
924 })
925 })
926 .collect()
927 } else {
928 Vec::default()
929 }
930 }
931
    /// Appends a user message, recording context-referenced buffers in the
    /// action log and stashing `git_checkpoint` as the message's pending
    /// checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            // Record the buffers this context referenced in the action log.
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // The checkpoint is finalized later, once generation completes.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
968
    /// Appends a hidden user message prompting the model to continue, and
    /// clears any pending checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Appends an assistant message with the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    /// Appends a message with the next sequential ID, bumps `updated_at`, and
    /// emits `ThreadEvent::MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1020
    /// Rewrites the role and segments of an existing message in place.
    ///
    /// Returns `false` when no message with `id` exists. `loaded_context`
    /// and `checkpoint` are only applied when provided.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes the message with `id`; returns `false` when it doesn't exist.
    // NOTE(review): unlike `truncate`, this leaves any matching entry in
    // `checkpoints_by_message` behind — confirm whether that's intentional.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1061
1062 /// Returns the representation of this [`Thread`] in a textual form.
1063 ///
1064 /// This is the representation we use when attaching a thread as context to another thread.
1065 pub fn text(&self) -> String {
1066 let mut text = String::new();
1067
1068 for message in &self.messages {
1069 text.push_str(match message.role {
1070 language_model::Role::User => "User:",
1071 language_model::Role::Assistant => "Agent:",
1072 language_model::Role::System => "System:",
1073 });
1074 text.push('\n');
1075
1076 for segment in &message.segments {
1077 match segment {
1078 MessageSegment::Text(content) => text.push_str(content),
1079 MessageSegment::Thinking { text: content, .. } => {
1080 text.push_str(&format!("<think>{}</think>", content))
1081 }
1082 MessageSegment::RedactedThinking(_) => {}
1083 }
1084 }
1085 text.push('\n');
1086 }
1087
1088 text
1089 }
1090
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot task so the serialized thread embeds it.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the rendered context text is persisted; handles
                        // and images are dropped (see `deserialize`).
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
            })
        })
    }
1175
/// Returns how many agentic turns this thread may still take before
/// `send_to_model` becomes a no-op.
pub fn remaining_turns(&self) -> u32 {
    self.remaining_turns
}
1179
/// Sets the number of turns this thread may still take.
pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
    self.remaining_turns = remaining_turns;
}
1183
1184 pub fn send_to_model(
1185 &mut self,
1186 model: Arc<dyn LanguageModel>,
1187 intent: CompletionIntent,
1188 window: Option<AnyWindowHandle>,
1189 cx: &mut Context<Self>,
1190 ) {
1191 if self.remaining_turns == 0 {
1192 return;
1193 }
1194
1195 self.remaining_turns -= 1;
1196
1197 let request = self.to_completion_request(model.clone(), intent, cx);
1198
1199 self.stream_completion(request, model, window, cx);
1200 }
1201
1202 pub fn used_tools_since_last_user_message(&self) -> bool {
1203 for message in self.messages.iter().rev() {
1204 if self.tool_use.message_has_tool_results(message.id) {
1205 return true;
1206 } else if message.role == Role::User {
1207 return false;
1208 }
1209 }
1210
1211 false
1212 }
1213
/// Builds the `LanguageModelRequest` for the next completion turn.
///
/// The request contains, in order: the system prompt (when the project
/// context is ready), each thread message with its loaded context and
/// segments, paired tool uses/results, and a trailing stale-file notice.
/// The most recent fully-resolved message is marked as the prompt-cache
/// anchor.
pub fn to_completion_request(
    &self,
    model: Arc<dyn LanguageModel>,
    intent: CompletionIntent,
    cx: &mut Context<Self>,
) -> LanguageModelRequest {
    let mut request = LanguageModelRequest {
        thread_id: Some(self.id.to_string()),
        prompt_id: Some(self.last_prompt_id.to_string()),
        intent: Some(intent),
        mode: None,
        messages: vec![],
        tools: Vec::new(),
        tool_choice: None,
        stop: Vec::new(),
        temperature: AgentSettings::temperature_for_model(&model, cx),
    };

    let available_tools = self.available_tools(cx, model.clone());
    let available_tool_names = available_tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();

    let model_context = &ModelContext {
        available_tools: available_tool_names,
    };

    // Render the system prompt from the project context. On failure (or if
    // the context isn't ready yet) surface an error event and send the
    // request without a system message.
    if let Some(project_context) = self.project_context.borrow().as_ref() {
        match self
            .prompt_builder
            .generate_assistant_system_prompt(project_context, model_context)
        {
            Err(err) => {
                let message = format!("{err:?}").into();
                log::error!("{message}");
                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                    header: "Error generating system prompt".into(),
                    message,
                }));
            }
            Ok(system_prompt) => {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        }
    } else {
        let message = "Context for system prompt unexpectedly not ready.".into();
        log::error!("{message}");
        cx.emit(ThreadEvent::ShowError(ThreadError::Message {
            header: "Error generating system prompt".into(),
            message,
        }));
    }

    // Index (into `request.messages`) of the newest message whose tool uses
    // are all resolved; it becomes the cache anchor after the loop.
    let mut message_ix_to_cache = None;
    for message in &self.messages {
        let mut request_message = LanguageModelRequestMessage {
            role: message.role,
            content: Vec::new(),
            cache: false,
        };

        // Context attached to the message goes first, before its segments.
        message
            .loaded_context
            .add_to_request_message(&mut request_message);

        for segment in &message.segments {
            match segment {
                MessageSegment::Text(text) => {
                    if !text.is_empty() {
                        request_message
                            .content
                            .push(MessageContent::Text(text.into()));
                    }
                }
                MessageSegment::Thinking { text, signature } => {
                    if !text.is_empty() {
                        request_message.content.push(MessageContent::Thinking {
                            text: text.into(),
                            signature: signature.clone(),
                        });
                    }
                }
                MessageSegment::RedactedThinking(data) => {
                    request_message
                        .content
                        .push(MessageContent::RedactedThinking(data.clone()));
                }
            };
        }

        // Completed tool results are sent as a separate User message that
        // immediately follows the message containing the tool uses.
        let mut cache_message = true;
        let mut tool_results_message = LanguageModelRequestMessage {
            role: Role::User,
            content: Vec::new(),
            cache: false,
        };
        for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
            if let Some(tool_result) = tool_result {
                request_message
                    .content
                    .push(MessageContent::ToolUse(tool_use.clone()));
                tool_results_message
                    .content
                    .push(MessageContent::ToolResult(LanguageModelToolResult {
                        tool_use_id: tool_use.id.clone(),
                        tool_name: tool_result.tool_name.clone(),
                        is_error: tool_result.is_error,
                        content: if tool_result.content.is_empty() {
                            // Surprisingly, the API fails if we return an empty string here.
                            // It thinks we are sending a tool use without a tool result.
                            "<Tool returned an empty string>".into()
                        } else {
                            tool_result.content.clone()
                        },
                        output: None,
                    }));
            } else {
                // A pending tool use means this message is still changing,
                // so it cannot serve as the cache anchor.
                cache_message = false;
                log::debug!(
                    "skipped tool use {:?} because it is still pending",
                    tool_use
                );
            }
        }

        if cache_message {
            message_ix_to_cache = Some(request.messages.len());
        }
        request.messages.push(request_message);

        if !tool_results_message.content.is_empty() {
            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(tool_results_message);
        }
    }

    // Mark the last stable message for prompt caching.
    // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
    if let Some(message_ix_to_cache) = message_ix_to_cache {
        request.messages[message_ix_to_cache].cache = true;
    }

    self.attached_tracked_files_state(&mut request.messages, cx);

    request.tools = available_tools;
    // Only request max mode when the model supports it; otherwise be
    // explicit about normal mode rather than leaving it unset.
    request.mode = if model.supports_max_mode() {
        Some(self.completion_mode.into())
    } else {
        Some(CompletionMode::Normal.into())
    };

    request
}
1373
1374 fn to_summarize_request(
1375 &self,
1376 model: &Arc<dyn LanguageModel>,
1377 intent: CompletionIntent,
1378 added_user_message: String,
1379 cx: &App,
1380 ) -> LanguageModelRequest {
1381 let mut request = LanguageModelRequest {
1382 thread_id: None,
1383 prompt_id: None,
1384 intent: Some(intent),
1385 mode: None,
1386 messages: vec![],
1387 tools: Vec::new(),
1388 tool_choice: None,
1389 stop: Vec::new(),
1390 temperature: AgentSettings::temperature_for_model(model, cx),
1391 };
1392
1393 for message in &self.messages {
1394 let mut request_message = LanguageModelRequestMessage {
1395 role: message.role,
1396 content: Vec::new(),
1397 cache: false,
1398 };
1399
1400 for segment in &message.segments {
1401 match segment {
1402 MessageSegment::Text(text) => request_message
1403 .content
1404 .push(MessageContent::Text(text.clone())),
1405 MessageSegment::Thinking { .. } => {}
1406 MessageSegment::RedactedThinking(_) => {}
1407 }
1408 }
1409
1410 if request_message.content.is_empty() {
1411 continue;
1412 }
1413
1414 request.messages.push(request_message);
1415 }
1416
1417 request.messages.push(LanguageModelRequestMessage {
1418 role: Role::User,
1419 content: vec![MessageContent::Text(added_user_message)],
1420 cache: false,
1421 });
1422
1423 request
1424 }
1425
1426 fn attached_tracked_files_state(
1427 &self,
1428 messages: &mut Vec<LanguageModelRequestMessage>,
1429 cx: &App,
1430 ) {
1431 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1432
1433 let mut stale_message = String::new();
1434
1435 let action_log = self.action_log.read(cx);
1436
1437 for stale_file in action_log.stale_buffers(cx) {
1438 let Some(file) = stale_file.read(cx).file() else {
1439 continue;
1440 };
1441
1442 if stale_message.is_empty() {
1443 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1444 }
1445
1446 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1447 }
1448
1449 let mut content = Vec::with_capacity(2);
1450
1451 if !stale_message.is_empty() {
1452 content.push(stale_message.into());
1453 }
1454
1455 if !content.is_empty() {
1456 let context_message = LanguageModelRequestMessage {
1457 role: Role::User,
1458 content,
1459 cache: false,
1460 };
1461
1462 messages.push(context_message);
1463 }
1464 }
1465
/// Streams a completion from `model`, applying events to the thread as
/// they arrive, and registers the work as a pending completion.
///
/// The spawned task:
/// - forwards text/thinking chunks into the latest assistant message
///   (creating one when needed),
/// - records tool uses, token usage, and queue-status updates,
/// - on a clean stop, either runs pending tools (`ToolUse`) or clears the
///   agent location; a `Refusal` additionally rolls back the refused turn,
/// - on error, emits the appropriate `ThreadEvent::ShowError` variant and
///   cancels the completion,
/// - finally reports telemetry for the whole request.
pub fn stream_completion(
    &mut self,
    request: LanguageModelRequest,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    // Each new request resets the per-request tool-use limit flag.
    self.tool_use_limit_reached = false;

    let pending_completion_id = post_inc(&mut self.completion_count);
    // Only capture the request/response pair when a debug callback is set.
    let mut request_callback_parameters = if self.request_callback.is_some() {
        Some((request.clone(), Vec::new()))
    } else {
        None
    };
    let prompt_id = self.last_prompt_id.clone();
    let tool_use_metadata = ToolUseMetadata {
        model: model.clone(),
        thread_id: self.id.clone(),
        prompt_id: prompt_id.clone(),
    };

    self.last_received_chunk_at = Some(Instant::now());

    let task = cx.spawn(async move |thread, cx| {
        let stream_completion_future = model.stream_completion(request, &cx);
        // Snapshot usage before streaming so the telemetry delta below only
        // reflects this request.
        let initial_token_usage =
            thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
        let stream_completion = async {
            let mut events = stream_completion_future.await?;

            let mut stop_reason = StopReason::EndTurn;
            let mut current_token_usage = TokenUsage::default();

            thread
                .update(cx, |_thread, cx| {
                    cx.emit(ThreadEvent::NewRequest);
                })
                .ok();

            // The assistant message created for this request, if any yet.
            let mut request_assistant_message_id = None;

            while let Some(event) = events.next().await {
                if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                    response_events
                        .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                }

                thread.update(cx, |thread, cx| {
                    let event = match event {
                        Ok(event) => event,
                        // Invalid tool-input JSON is recoverable: record a
                        // tool error and keep consuming the stream.
                        Err(LanguageModelCompletionError::BadInputJson {
                            id,
                            tool_name,
                            raw_input: invalid_input_json,
                            json_parse_error,
                        }) => {
                            thread.receive_invalid_tool_json(
                                id,
                                tool_name,
                                invalid_input_json,
                                json_parse_error,
                                window,
                                cx,
                            );
                            return Ok(());
                        }
                        Err(LanguageModelCompletionError::Other(error)) => {
                            return Err(error);
                        }
                    };

                    match event {
                        LanguageModelCompletionEvent::StartMessage { .. } => {
                            request_assistant_message_id =
                                Some(thread.insert_assistant_message(
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                ));
                        }
                        LanguageModelCompletionEvent::Stop(reason) => {
                            stop_reason = reason;
                        }
                        LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                            // Usage updates are cumulative per request, so
                            // add only the delta since the previous update.
                            thread.update_token_usage_at_last_message(token_usage);
                            thread.cumulative_token_usage = thread.cumulative_token_usage
                                + token_usage
                                - current_token_usage;
                            current_token_usage = token_usage;
                        }
                        LanguageModelCompletionEvent::Text(chunk) => {
                            thread.received_chunk();

                            cx.emit(ThreadEvent::ReceivedTextChunk);
                            if let Some(last_message) = thread.messages.last_mut() {
                                if last_message.role == Role::Assistant
                                    && !thread.tool_use.has_tool_results(last_message.id)
                                {
                                    last_message.push_text(&chunk);
                                    cx.emit(ThreadEvent::StreamedAssistantText(
                                        last_message.id,
                                        chunk,
                                    ));
                                } else {
                                    // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                    // of a new Assistant response.
                                    //
                                    // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                    // will result in duplicating the text of the chunk in the rendered Markdown.
                                    request_assistant_message_id =
                                        Some(thread.insert_assistant_message(
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        ));
                                };
                            }
                        }
                        LanguageModelCompletionEvent::Thinking {
                            text: chunk,
                            signature,
                        } => {
                            thread.received_chunk();

                            if let Some(last_message) = thread.messages.last_mut() {
                                if last_message.role == Role::Assistant
                                    && !thread.tool_use.has_tool_results(last_message.id)
                                {
                                    last_message.push_thinking(&chunk, signature);
                                    cx.emit(ThreadEvent::StreamedAssistantThinking(
                                        last_message.id,
                                        chunk,
                                    ));
                                } else {
                                    // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                    // of a new Assistant response.
                                    //
                                    // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                    // will result in duplicating the text of the chunk in the rendered Markdown.
                                    request_assistant_message_id =
                                        Some(thread.insert_assistant_message(
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        ));
                                };
                            }
                        }
                        LanguageModelCompletionEvent::ToolUse(tool_use) => {
                            // Tool uses must hang off an assistant message;
                            // create an empty one if none exists yet.
                            let last_assistant_message_id = request_assistant_message_id
                                .unwrap_or_else(|| {
                                    let new_assistant_message_id =
                                        thread.insert_assistant_message(vec![], cx);
                                    request_assistant_message_id =
                                        Some(new_assistant_message_id);
                                    new_assistant_message_id
                                });

                            let tool_use_id = tool_use.id.clone();
                            // Only partially-streamed inputs are re-emitted
                            // so the UI can render them incrementally.
                            let streamed_input = if tool_use.is_input_complete {
                                None
                            } else {
                                Some((&tool_use.input).clone())
                            };

                            let ui_text = thread.tool_use.request_tool_use(
                                last_assistant_message_id,
                                tool_use,
                                tool_use_metadata.clone(),
                                cx,
                            );

                            if let Some(input) = streamed_input {
                                cx.emit(ThreadEvent::StreamedToolUse {
                                    tool_use_id,
                                    ui_text,
                                    input,
                                });
                            }
                        }
                        LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                            if let Some(completion) = thread
                                .pending_completions
                                .iter_mut()
                                .find(|completion| completion.id == pending_completion_id)
                            {
                                match status_update {
                                    CompletionRequestStatus::Queued {
                                        position,
                                    } => {
                                        completion.queue_state = QueueState::Queued { position };
                                    }
                                    CompletionRequestStatus::Started => {
                                        completion.queue_state = QueueState::Started;
                                    }
                                    CompletionRequestStatus::Failed {
                                        code, message, request_id
                                    } => {
                                        anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                    }
                                    CompletionRequestStatus::UsageUpdated {
                                        amount, limit
                                    } => {
                                        let usage = RequestUsage { limit, amount: amount as i32 };

                                        thread.last_usage = Some(usage);
                                    }
                                    CompletionRequestStatus::ToolUseLimitReached => {
                                        thread.tool_use_limit_reached = true;
                                    }
                                }
                            }
                        }
                    }

                    thread.touch_updated_at();
                    cx.emit(ThreadEvent::StreamedCompletion);
                    cx.notify();

                    thread.auto_capture_telemetry(cx);
                    Ok(())
                })??;

                // Yield so other tasks can make progress between events.
                smol::future::yield_now().await;
            }

            thread.update(cx, |thread, cx| {
                thread.last_received_chunk_at = None;
                thread
                    .pending_completions
                    .retain(|completion| completion.id != pending_completion_id);

                // If there is a response without tool use, summarize the message. Otherwise,
                // allow two tool uses before summarizing.
                if matches!(thread.summary, ThreadSummary::Pending)
                    && thread.messages.len() >= 2
                    && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                {
                    thread.summarize(cx);
                }
            })?;

            anyhow::Ok(stop_reason)
        };

        let result = stream_completion.await;

        thread
            .update(cx, |thread, cx| {
                thread.finalize_pending_checkpoint(cx);
                match result.as_ref() {
                    Ok(stop_reason) => match stop_reason {
                        StopReason::ToolUse => {
                            let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                            cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                        }
                        StopReason::EndTurn | StopReason::MaxTokens => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });
                        }
                        StopReason::Refusal => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Remove the turn that was refused.
                            //
                            // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                            {
                                let mut messages_to_remove = Vec::new();

                                // Walk backwards collecting messages up to
                                // (and including) the start of the refused
                                // user turn.
                                for (ix, message) in thread.messages.iter().enumerate().rev() {
                                    messages_to_remove.push(message.id);

                                    if message.role == Role::User {
                                        if ix == 0 {
                                            break;
                                        }

                                        if let Some(prev_message) = thread.messages.get(ix - 1) {
                                            if prev_message.role == Role::Assistant {
                                                break;
                                            }
                                        }
                                    }
                                }

                                for message_id in messages_to_remove {
                                    thread.delete_message(message_id, cx);
                                }
                            }

                            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                header: "Language model refusal".into(),
                                message: "Model refused to generate content for safety reasons.".into(),
                            }));
                        }
                    },
                    Err(error) => {
                        thread.project.update(cx, |project, cx| {
                            project.set_agent_location(None, cx);
                        });

                        // Map known error types to their dedicated UI
                        // treatments; fall back to a generic message.
                        if error.is::<PaymentRequiredError>() {
                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                        } else if let Some(error) =
                            error.downcast_ref::<ModelRequestLimitReachedError>()
                        {
                            cx.emit(ThreadEvent::ShowError(
                                ThreadError::ModelRequestLimitReached { plan: error.plan },
                            ));
                        } else if let Some(known_error) =
                            error.downcast_ref::<LanguageModelKnownError>()
                        {
                            match known_error {
                                LanguageModelKnownError::ContextWindowLimitExceeded {
                                    tokens,
                                } => {
                                    thread.exceeded_window_error = Some(ExceededWindowError {
                                        model_id: model.id(),
                                        token_count: *tokens,
                                    });
                                    cx.notify();
                                }
                            }
                        } else {
                            let error_message = error
                                .chain()
                                .map(|err| err.to_string())
                                .collect::<Vec<_>>()
                                .join("\n");
                            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                header: "Error interacting with language model".into(),
                                message: SharedString::from(error_message.clone()),
                            }));
                        }

                        thread.cancel_last_completion(window, cx);
                    }
                }

                cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                if let Some((request_callback, (request, response_events))) = thread
                    .request_callback
                    .as_mut()
                    .zip(request_callback_parameters.as_ref())
                {
                    request_callback(request, response_events);
                }

                thread.auto_capture_telemetry(cx);

                // Report only this request's token delta to telemetry.
                if let Ok(initial_usage) = initial_token_usage {
                    let usage = thread.cumulative_token_usage - initial_usage;

                    telemetry::event!(
                        "Assistant Thread Completion",
                        thread_id = thread.id().to_string(),
                        prompt_id = prompt_id,
                        model = model.telemetry_id(),
                        model_provider = model.provider_id().to_string(),
                        input_tokens = usage.input_tokens,
                        output_tokens = usage.output_tokens,
                        cache_creation_input_tokens = usage.cache_creation_input_tokens,
                        cache_read_input_tokens = usage.cache_read_input_tokens,
                    );
                }
            })
            .ok();
    });

    self.pending_completions.push(PendingCompletion {
        id: pending_completion_id,
        queue_state: QueueState::Sending,
        _task: task,
    });
}
1846
1847 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1848 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1849 println!("No thread summary model");
1850 return;
1851 };
1852
1853 if !model.provider.is_authenticated(cx) {
1854 return;
1855 }
1856
1857 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1858 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1859 If the conversation is about a specific subject, include it in the title. \
1860 Be descriptive. DO NOT speak in the first person.";
1861
1862 let request = self.to_summarize_request(
1863 &model.model,
1864 CompletionIntent::ThreadSummarization,
1865 added_user_message.into(),
1866 cx,
1867 );
1868
1869 self.summary = ThreadSummary::Generating;
1870
1871 self.pending_summary = cx.spawn(async move |this, cx| {
1872 let result = async {
1873 let mut messages = model.model.stream_completion(request, &cx).await?;
1874
1875 let mut new_summary = String::new();
1876 while let Some(event) = messages.next().await {
1877 let Ok(event) = event else {
1878 continue;
1879 };
1880 let text = match event {
1881 LanguageModelCompletionEvent::Text(text) => text,
1882 LanguageModelCompletionEvent::StatusUpdate(
1883 CompletionRequestStatus::UsageUpdated { amount, limit },
1884 ) => {
1885 this.update(cx, |thread, _cx| {
1886 thread.last_usage = Some(RequestUsage {
1887 limit,
1888 amount: amount as i32,
1889 });
1890 })?;
1891 continue;
1892 }
1893 _ => continue,
1894 };
1895
1896 let mut lines = text.lines();
1897 new_summary.extend(lines.next());
1898
1899 // Stop if the LLM generated multiple lines.
1900 if lines.next().is_some() {
1901 break;
1902 }
1903 }
1904
1905 anyhow::Ok(new_summary)
1906 }
1907 .await;
1908
1909 this.update(cx, |this, cx| {
1910 match result {
1911 Ok(new_summary) => {
1912 if new_summary.is_empty() {
1913 this.summary = ThreadSummary::Error;
1914 } else {
1915 this.summary = ThreadSummary::Ready(new_summary.into());
1916 }
1917 }
1918 Err(err) => {
1919 this.summary = ThreadSummary::Error;
1920 log::error!("Failed to generate thread summary: {}", err);
1921 }
1922 }
1923 cx.emit(ThreadEvent::SummaryGenerated);
1924 })
1925 .log_err()?;
1926
1927 Some(())
1928 });
1929 }
1930
/// Kicks off generation of a detailed Markdown summary unless one is
/// already generating (or generated) for the current last message.
///
/// Results are published through `detailed_summary_tx`; on success the
/// thread is also saved via `thread_store` so the summary can be reused.
pub fn start_generating_detailed_summary_if_needed(
    &mut self,
    thread_store: WeakEntity<ThreadStore>,
    cx: &mut Context<Self>,
) {
    let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
        return;
    };

    // A summary keyed to the current last message is already in flight or
    // finished; nothing to do.
    match &*self.detailed_summary_rx.borrow() {
        DetailedSummaryState::Generating { message_id, .. }
        | DetailedSummaryState::Generated { message_id, .. }
            if *message_id == last_message_id =>
        {
            // Already up-to-date
            return;
        }
        _ => {}
    }

    let Some(ConfiguredModel { model, provider }) =
        LanguageModelRegistry::read_global(cx).thread_summary_model()
    else {
        return;
    };

    if !provider.is_authenticated(cx) {
        return;
    }

    let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
        1. A brief overview of what was discussed\n\
        2. Key facts or information discovered\n\
        3. Outcomes or conclusions reached\n\
        4. Any action items or next steps if any\n\
        Format it in Markdown with headings and bullet points.";

    let request = self.to_summarize_request(
        &model,
        CompletionIntent::ThreadContextSummarization,
        added_user_message.into(),
        cx,
    );

    *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
        message_id: last_message_id,
    };

    // Replace the detailed summarization task if there is one, cancelling it. It would probably
    // be better to allow the old task to complete, but this would require logic for choosing
    // which result to prefer (the old task could complete after the new one, resulting in a
    // stale summary).
    self.detailed_summary_task = cx.spawn(async move |thread, cx| {
        let stream = model.stream_completion_text(request, &cx);
        let Some(mut messages) = stream.await.log_err() else {
            // Stream setup failed: roll the state back so a later call can
            // retry.
            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() =
                        DetailedSummaryState::NotGenerated;
                })
                .ok()?;
            return None;
        };

        let mut new_detailed_summary = String::new();

        while let Some(chunk) = messages.stream.next().await {
            if let Some(chunk) = chunk.log_err() {
                new_detailed_summary.push_str(&chunk);
            }
        }

        thread
            .update(cx, |thread, _cx| {
                *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                    text: new_detailed_summary.into(),
                    message_id: last_message_id,
                };
            })
            .ok()?;

        // Save thread so its summary can be reused later
        if let Some(thread) = thread.upgrade() {
            if let Ok(Ok(save_task)) = cx.update(|cx| {
                thread_store
                    .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            }) {
                save_task.await.log_err();
            }
        }

        Some(())
    });
}
2025
/// Waits until detailed-summary generation settles, then returns the
/// generated summary, or the thread's plain text when no summary exists.
///
/// Returns `None` if the thread entity is dropped or the state channel
/// closes while waiting.
pub async fn wait_for_detailed_summary_or_text(
    this: &Entity<Self>,
    cx: &mut AsyncApp,
) -> Option<SharedString> {
    let mut detailed_summary_rx = this
        .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
        .ok()?;
    loop {
        match detailed_summary_rx.recv().await? {
            // Still generating: keep waiting for the next state change.
            DetailedSummaryState::Generating { .. } => {}
            DetailedSummaryState::NotGenerated => {
                return this.read_with(cx, |this, _cx| this.text().into()).ok();
            }
            DetailedSummaryState::Generated { text, .. } => return Some(text),
        }
    }
}
2043
2044 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2045 self.detailed_summary_rx
2046 .borrow()
2047 .text()
2048 .unwrap_or_else(|| self.text().into())
2049 }
2050
/// Reports whether a detailed summary is currently being generated.
pub fn is_generating_detailed_summary(&self) -> bool {
    matches!(
        &*self.detailed_summary_rx.borrow(),
        DetailedSummaryState::Generating { .. }
    )
}
2057
/// Dispatches every idle pending tool use: tools requiring confirmation
/// (unless the always-allow setting is on) are queued for user approval,
/// known tools are run, and unknown tool names are reported back to the
/// model as errors.
///
/// Returns the list of pending tool uses that were dispatched.
pub fn use_pending_tools(
    &mut self,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
    model: Arc<dyn LanguageModel>,
) -> Vec<PendingToolUse> {
    self.auto_capture_telemetry(cx);
    // The same request snapshot is shared by every tool invocation below.
    let request =
        Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
    let pending_tool_uses = self
        .tool_use
        .pending_tool_uses()
        .into_iter()
        .filter(|tool_use| tool_use.status.is_idle())
        .cloned()
        .collect::<Vec<_>>();

    for tool_use in pending_tool_uses.iter() {
        if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
            // Confirmation is required per-tool unless the user opted into
            // always allowing tool actions.
            if tool.needs_confirmation(&tool_use.input, cx)
                && !AgentSettings::get_global(cx).always_allow_tool_actions
            {
                self.tool_use.confirm_tool_use(
                    tool_use.id.clone(),
                    tool_use.ui_text.clone(),
                    tool_use.input.clone(),
                    request.clone(),
                    tool,
                );
                cx.emit(ThreadEvent::ToolConfirmationNeeded);
            } else {
                self.run_tool(
                    tool_use.id.clone(),
                    tool_use.ui_text.clone(),
                    tool_use.input.clone(),
                    request.clone(),
                    tool,
                    model.clone(),
                    window,
                    cx,
                );
            }
        } else {
            // The model asked for a tool that doesn't exist or is disabled.
            self.handle_hallucinated_tool_use(
                tool_use.id.clone(),
                tool_use.name.clone(),
                window,
                cx,
            );
        }
    }

    pending_tool_uses
}
2112
2113 pub fn handle_hallucinated_tool_use(
2114 &mut self,
2115 tool_use_id: LanguageModelToolUseId,
2116 hallucinated_tool_name: Arc<str>,
2117 window: Option<AnyWindowHandle>,
2118 cx: &mut Context<Thread>,
2119 ) {
2120 let available_tools = self.tools.read(cx).enabled_tools(cx);
2121
2122 let tool_list = available_tools
2123 .iter()
2124 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2125 .collect::<Vec<_>>()
2126 .join("\n");
2127
2128 let error_message = format!(
2129 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2130 hallucinated_tool_name, tool_list
2131 );
2132
2133 let pending_tool_use = self.tool_use.insert_tool_output(
2134 tool_use_id.clone(),
2135 hallucinated_tool_name,
2136 Err(anyhow!("Missing tool call: {error_message}")),
2137 self.configured_model.as_ref(),
2138 );
2139
2140 cx.emit(ThreadEvent::MissingToolUse {
2141 tool_use_id: tool_use_id.clone(),
2142 ui_text: error_message.into(),
2143 });
2144
2145 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2146 }
2147
/// Handles a tool use whose input JSON failed to parse: records the parse
/// error as the tool's output, emits `ThreadEvent::InvalidToolInput`, and
/// finishes the tool use.
pub fn receive_invalid_tool_json(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    tool_name: Arc<str>,
    invalid_json: Arc<str>,
    error: String,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) {
    log::error!("The model returned invalid input JSON: {invalid_json}");

    let pending_tool_use = self.tool_use.insert_tool_output(
        tool_use_id.clone(),
        tool_name,
        Err(anyhow!("Error parsing input JSON: {error}")),
        self.configured_model.as_ref(),
    );
    // Prefer the pending use's display text; fall back to a generic label
    // (this shouldn't happen, so it is also logged).
    let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
        pending_tool_use.ui_text.clone()
    } else {
        log::error!(
            "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
        );
        format!("Unknown tool {}", tool_use_id).into()
    };

    cx.emit(ThreadEvent::InvalidToolInput {
        tool_use_id: tool_use_id.clone(),
        ui_text,
        invalid_input_json: invalid_json,
    });

    self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
2182
2183 pub fn run_tool(
2184 &mut self,
2185 tool_use_id: LanguageModelToolUseId,
2186 ui_text: impl Into<SharedString>,
2187 input: serde_json::Value,
2188 request: Arc<LanguageModelRequest>,
2189 tool: Arc<dyn Tool>,
2190 model: Arc<dyn LanguageModel>,
2191 window: Option<AnyWindowHandle>,
2192 cx: &mut Context<Thread>,
2193 ) {
2194 let task =
2195 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2196 self.tool_use
2197 .run_pending_tool(tool_use_id, ui_text.into(), task);
2198 }
2199
/// Runs `tool` (or short-circuits with an error when the tool is
/// disabled), stores its result card if any, and returns a task that
/// records the tool's output on the thread when it completes.
fn spawn_tool_use(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    request: Arc<LanguageModelRequest>,
    input: serde_json::Value,
    tool: Arc<dyn Tool>,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) -> Task<()> {
    let tool_name: Arc<str> = tool.name().into();

    // Disabled tools still produce a result (an error) so the model gets a
    // response for its tool use instead of a dangling call.
    let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
        Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
    } else {
        tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        )
    };

    // Store the card separately if it exists
    if let Some(card) = tool_result.card.clone() {
        self.tool_use
            .insert_tool_result_card(tool_use_id.clone(), card);
    }

    cx.spawn({
        async move |thread: WeakEntity<Thread>, cx| {
            let output = tool_result.output.await;

            thread
                .update(cx, |thread, cx| {
                    let pending_tool_use = thread.tool_use.insert_tool_output(
                        tool_use_id.clone(),
                        tool_name,
                        output,
                        thread.configured_model.as_ref(),
                    );
                    thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                })
                .ok();
        }
    })
}
2250
/// Called whenever a tool use completes (or is canceled).
///
/// When this was the last outstanding tool use, automatically sends the
/// accumulated tool results back to the configured model — unless the
/// completion was `canceled` — and emits `ThreadEvent::ToolFinished`.
fn tool_finished(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    pending_tool_use: Option<PendingToolUse>,
    canceled: bool,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    if self.all_tools_finished() {
        if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
            // Don't auto-continue the conversation after a cancellation.
            if !canceled {
                self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
            }
            self.auto_capture_telemetry(cx);
        }
    }

    cx.emit(ThreadEvent::ToolFinished {
        tool_use_id,
        pending_tool_use,
    });
}
2273
2274 /// Cancels the last pending completion, if there are any pending.
2275 ///
2276 /// Returns whether a completion was canceled.
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also cancels all pending tool uses, routing each through
    /// `tool_finished` with `canceled = true` so no follow-up request is sent
    /// to the model.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Popping a pending completion drops its task, aborting the request.
        let mut canceled = self.pending_completions.pop().is_some();

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2310
2311 /// Signals that any in-progress editing should be canceled.
2312 ///
2313 /// This method is used to notify listeners (like ActiveThread) that
2314 /// they should cancel any editing operations.
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits an event; the
    /// listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2318
    /// Returns the thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2322
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2326
    /// Records per-message feedback and reports it to telemetry, together with
    /// a serialized snapshot of the thread and project state.
    ///
    /// No-ops (returning an immediately-ready `Ok`) when the same feedback has
    /// already been recorded for this message. The telemetry upload happens on
    /// a background task; the returned task resolves when the events have been
    /// flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // aborting the report.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2384
    /// Records thread-level feedback and reports it to telemetry.
    ///
    /// If the thread has an assistant message, the feedback is attributed to
    /// the most recent one via `report_message_feedback`; otherwise it is
    /// stored as thread-level feedback and reported without message fields.
    /// NOTE(review): this fallback path intentionally duplicates the telemetry
    /// logic of `report_message_feedback`, minus the per-message fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // Capture everything that needs `cx` before moving to the background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather than
                // aborting the report.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2430
2431 /// Create a snapshot of the current project state including git information and unsaved buffers.
2432 fn project_snapshot(
2433 project: Entity<Project>,
2434 cx: &mut Context<Self>,
2435 ) -> Task<Arc<ProjectSnapshot>> {
2436 let git_store = project.read(cx).git_store().clone();
2437 let worktree_snapshots: Vec<_> = project
2438 .read(cx)
2439 .visible_worktrees(cx)
2440 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2441 .collect();
2442
2443 cx.spawn(async move |_, cx| {
2444 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2445
2446 let mut unsaved_buffers = Vec::new();
2447 cx.update(|app_cx| {
2448 let buffer_store = project.read(app_cx).buffer_store();
2449 for buffer_handle in buffer_store.read(app_cx).buffers() {
2450 let buffer = buffer_handle.read(app_cx);
2451 if buffer.is_dirty() {
2452 if let Some(file) = buffer.file() {
2453 let path = file.path().to_string_lossy().to_string();
2454 unsaved_buffers.push(path);
2455 }
2456 }
2457 }
2458 })
2459 .ok();
2460
2461 Arc::new(ProjectSnapshot {
2462 worktree_snapshots,
2463 unsaved_buffer_paths: unsaved_buffers,
2464 timestamp: Utc::now(),
2465 })
2466 })
2467 }
2468
    /// Builds a snapshot of a single worktree: its absolute path plus git
    /// metadata (remote URL, HEAD sha, current branch, and a HEAD-to-worktree
    /// diff) when a local repository covers the worktree.
    ///
    /// Every failure mode degrades gracefully: a dropped worktree yields an
    /// empty-path snapshot, and any git error yields `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return a placeholder snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the first repository whose root contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a git backend;
                            // otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the layered fallibility: no repo found, update failed,
            // or the job itself failed all collapse to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2546
2547 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2548 let mut markdown = Vec::new();
2549
2550 let summary = self.summary().or_default();
2551 writeln!(markdown, "# {summary}\n")?;
2552
2553 for message in self.messages() {
2554 writeln!(
2555 markdown,
2556 "## {role}\n",
2557 role = match message.role {
2558 Role::User => "User",
2559 Role::Assistant => "Agent",
2560 Role::System => "System",
2561 }
2562 )?;
2563
2564 if !message.loaded_context.text.is_empty() {
2565 writeln!(markdown, "{}", message.loaded_context.text)?;
2566 }
2567
2568 if !message.loaded_context.images.is_empty() {
2569 writeln!(
2570 markdown,
2571 "\n{} images attached as context.\n",
2572 message.loaded_context.images.len()
2573 )?;
2574 }
2575
2576 for segment in &message.segments {
2577 match segment {
2578 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2579 MessageSegment::Thinking { text, .. } => {
2580 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2581 }
2582 MessageSegment::RedactedThinking(_) => {}
2583 }
2584 }
2585
2586 for tool_use in self.tool_uses_for_message(message.id, cx) {
2587 writeln!(
2588 markdown,
2589 "**Use Tool: {} ({})**",
2590 tool_use.name, tool_use.id
2591 )?;
2592 writeln!(markdown, "```json")?;
2593 writeln!(
2594 markdown,
2595 "{}",
2596 serde_json::to_string_pretty(&tool_use.input)?
2597 )?;
2598 writeln!(markdown, "```")?;
2599 }
2600
2601 for tool_result in self.tool_results_for_message(message.id) {
2602 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2603 if tool_result.is_error {
2604 write!(markdown, " (Error)")?;
2605 }
2606
2607 writeln!(markdown, "**\n")?;
2608 match &tool_result.content {
2609 LanguageModelToolResultContent::Text(text) => {
2610 writeln!(markdown, "{text}")?;
2611 }
2612 LanguageModelToolResultContent::Image(image) => {
2613 writeln!(markdown, "", image.source)?;
2614 }
2615 }
2616
2617 if let Some(output) = tool_result.output.as_ref() {
2618 writeln!(
2619 markdown,
2620 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2621 serde_json::to_string_pretty(output)?
2622 )?;
2623 }
2624 }
2625 }
2626
2627 Ok(String::from_utf8_lossy(&markdown).to_string())
2628 }
2629
2630 pub fn keep_edits_in_range(
2631 &mut self,
2632 buffer: Entity<language::Buffer>,
2633 buffer_range: Range<language::Anchor>,
2634 cx: &mut Context<Self>,
2635 ) {
2636 self.action_log.update(cx, |action_log, cx| {
2637 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2638 });
2639 }
2640
2641 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2642 self.action_log
2643 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2644 }
2645
2646 pub fn reject_edits_in_ranges(
2647 &mut self,
2648 buffer: Entity<language::Buffer>,
2649 buffer_ranges: Vec<Range<language::Anchor>>,
2650 cx: &mut Context<Self>,
2651 ) -> Task<Result<()>> {
2652 self.action_log.update(cx, |action_log, cx| {
2653 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2654 })
2655 }
2656
    /// Returns the action log tracking the agent's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2660
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2664
2665 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2666 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2667 return;
2668 }
2669
2670 let now = Instant::now();
2671 if let Some(last) = self.last_auto_capture_at {
2672 if now.duration_since(last).as_secs() < 10 {
2673 return;
2674 }
2675 }
2676
2677 self.last_auto_capture_at = Some(now);
2678
2679 let thread_id = self.id().clone();
2680 let github_login = self
2681 .project
2682 .read(cx)
2683 .user_store()
2684 .read(cx)
2685 .current_user()
2686 .map(|user| user.github_login.clone());
2687 let client = self.project.read(cx).client();
2688 let serialize_task = self.serialize(cx);
2689
2690 cx.background_executor()
2691 .spawn(async move {
2692 if let Ok(serialized_thread) = serialize_task.await {
2693 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2694 telemetry::event!(
2695 "Agent Thread Auto-Captured",
2696 thread_id = thread_id.to_string(),
2697 thread_data = thread_data,
2698 auto_capture_reason = "tracked_user",
2699 github_login = github_login
2700 );
2701
2702 client.telemetry().flush_events().await;
2703 }
2704 }
2705 })
2706 .detach();
2707 }
2708
    /// Returns the token usage accumulated over the entire lifetime of this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2712
2713 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2714 let Some(model) = self.configured_model.as_ref() else {
2715 return TotalTokenUsage::default();
2716 };
2717
2718 let max = model.model.max_token_count();
2719
2720 let index = self
2721 .messages
2722 .iter()
2723 .position(|msg| msg.id == message_id)
2724 .unwrap_or(0);
2725
2726 if index == 0 {
2727 return TotalTokenUsage { total: 0, max };
2728 }
2729
2730 let token_usage = &self
2731 .request_token_usage
2732 .get(index - 1)
2733 .cloned()
2734 .unwrap_or_default();
2735
2736 TotalTokenUsage {
2737 total: token_usage.total_tokens() as usize,
2738 max,
2739 }
2740 }
2741
2742 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2743 let model = self.configured_model.as_ref()?;
2744
2745 let max = model.model.max_token_count();
2746
2747 if let Some(exceeded_error) = &self.exceeded_window_error {
2748 if model.model.id() == exceeded_error.model_id {
2749 return Some(TotalTokenUsage {
2750 total: exceeded_error.token_count,
2751 max,
2752 });
2753 }
2754 }
2755
2756 let total = self
2757 .token_usage_at_last_message()
2758 .unwrap_or_default()
2759 .total_tokens() as usize;
2760
2761 Some(TotalTokenUsage { total, max })
2762 }
2763
    /// Returns the token usage recorded for the last message, if any.
    ///
    /// Prefers the entry aligned with the last message index; if the usage
    /// vector is shorter than the message list, falls back to its last entry.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2770
    /// Records `token_usage` as the usage for the last message.
    ///
    /// First resizes the usage vector to match the message count, padding any
    /// gap with the previously known usage (so intermediate entries are not
    /// zeroed), then overwrites the final entry.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2780
2781 pub fn deny_tool_use(
2782 &mut self,
2783 tool_use_id: LanguageModelToolUseId,
2784 tool_name: Arc<str>,
2785 window: Option<AnyWindowHandle>,
2786 cx: &mut Context<Self>,
2787 ) {
2788 let err = Err(anyhow::anyhow!(
2789 "Permission to run tool action denied by user"
2790 ));
2791
2792 self.tool_use.insert_tool_output(
2793 tool_use_id.clone(),
2794 tool_name,
2795 err,
2796 self.configured_model.as_ref(),
2797 );
2798 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2799 }
2800}
2801
/// User-visible errors surfaced by a [`Thread`] (via [`ThreadEvent::ShowError`]).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before further completions.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has exhausted its model-request allowance.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2814
/// Events emitted by a [`Thread`] for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion chunk was received and applied.
    StreamedCompletion,
    /// A raw text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new request to the model was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that is not available.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished (with a stop reason) or failed.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    CancelEditing,
    CompletionCanceled,
}
2857
// Lets `Thread` broadcast `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2859
/// An in-flight model completion owned by a `Thread`.
struct PendingCompletion {
    // Monotonic id used to identify this completion among pending ones.
    id: usize,
    queue_state: QueueState,
    // Dropping this task aborts the underlying request (see
    // `cancel_last_completion`).
    _task: Task<()>,
}
2865
2866#[cfg(test)]
2867mod tests {
2868 use super::*;
2869 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2870 use agent_settings::{AgentSettings, LanguageModelParameters};
2871 use assistant_tool::ToolRegistry;
2872 use editor::EditorSettings;
2873 use gpui::TestAppContext;
2874 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2875 use project::{FakeFs, Project};
2876 use prompt_store::PromptBuilder;
2877 use serde_json::json;
2878 use settings::{Settings, SettingsStore};
2879 use std::sync::Arc;
2880 use theme::ThemeSettings;
2881 use util::path;
2882 use workspace::Workspace;
2883
    /// Verifies that a file attached as context is injected into both the
    /// stored message and the completion request, wrapped in the expected
    /// `<context>` markup.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // messages[0] is the system prompt; messages[1] is the user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2960
    /// Verifies that each new user message only carries context that earlier
    /// messages have not already included, and that editing from an earlier
    /// message (via `Some(message_id)`) re-includes later contexts correctly.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        // (plus the system prompt at index 0).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Editing from message 2 onward: contexts from messages 2 and 3 plus
        // the newly attached file4 should all count as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3116
    /// Verifies that messages inserted without any context carry an empty
    /// context string and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // messages[0] is the system prompt.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3194
    /// Verifies that when a buffer attached as context is modified after being
    /// read, the completion request gains a trailing "files changed since last
    /// read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3282
    /// Verifies how `model_parameters` temperature overrides are matched:
    /// by provider+model, by model alone, by provider alone — and that a
    /// matching model name under a *different* provider does not apply.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        // Provider mismatch means the override must not apply.
        assert_eq!(request.temperature, None);
    }
3376
3377 #[gpui::test]
3378 async fn test_thread_summary(cx: &mut TestAppContext) {
3379 init_test_settings(cx);
3380
3381 let project = create_test_project(cx, json!({})).await;
3382
3383 let (_, _thread_store, thread, _context_store, model) =
3384 setup_test_environment(cx, project.clone()).await;
3385
3386 // Initial state should be pending
3387 thread.read_with(cx, |thread, _| {
3388 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3389 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3390 });
3391
3392 // Manually setting the summary should not be allowed in this state
3393 thread.update(cx, |thread, cx| {
3394 thread.set_summary("This should not work", cx);
3395 });
3396
3397 thread.read_with(cx, |thread, _| {
3398 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3399 });
3400
3401 // Send a message
3402 thread.update(cx, |thread, cx| {
3403 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3404 thread.send_to_model(
3405 model.clone(),
3406 CompletionIntent::ThreadSummarization,
3407 None,
3408 cx,
3409 );
3410 });
3411
3412 let fake_model = model.as_fake();
3413 simulate_successful_response(&fake_model, cx);
3414
3415 // Should start generating summary when there are >= 2 messages
3416 thread.read_with(cx, |thread, _| {
3417 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3418 });
3419
3420 // Should not be able to set the summary while generating
3421 thread.update(cx, |thread, cx| {
3422 thread.set_summary("This should not work either", cx);
3423 });
3424
3425 thread.read_with(cx, |thread, _| {
3426 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3427 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3428 });
3429
3430 cx.run_until_parked();
3431 fake_model.stream_last_completion_response("Brief");
3432 fake_model.stream_last_completion_response(" Introduction");
3433 fake_model.end_last_completion_stream();
3434 cx.run_until_parked();
3435
3436 // Summary should be set
3437 thread.read_with(cx, |thread, _| {
3438 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3439 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3440 });
3441
3442 // Now we should be able to set a summary
3443 thread.update(cx, |thread, cx| {
3444 thread.set_summary("Brief Intro", cx);
3445 });
3446
3447 thread.read_with(cx, |thread, _| {
3448 assert_eq!(thread.summary().or_default(), "Brief Intro");
3449 });
3450
3451 // Test setting an empty summary (should default to DEFAULT)
3452 thread.update(cx, |thread, cx| {
3453 thread.set_summary("", cx);
3454 });
3455
3456 thread.read_with(cx, |thread, _| {
3457 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3458 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3459 });
3460 }
3461
3462 #[gpui::test]
3463 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3464 init_test_settings(cx);
3465
3466 let project = create_test_project(cx, json!({})).await;
3467
3468 let (_, _thread_store, thread, _context_store, model) =
3469 setup_test_environment(cx, project.clone()).await;
3470
3471 test_summarize_error(&model, &thread, cx);
3472
3473 // Now we should be able to set a summary
3474 thread.update(cx, |thread, cx| {
3475 thread.set_summary("Brief Intro", cx);
3476 });
3477
3478 thread.read_with(cx, |thread, _| {
3479 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3480 assert_eq!(thread.summary().or_default(), "Brief Intro");
3481 });
3482 }
3483
3484 #[gpui::test]
3485 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3486 init_test_settings(cx);
3487
3488 let project = create_test_project(cx, json!({})).await;
3489
3490 let (_, _thread_store, thread, _context_store, model) =
3491 setup_test_environment(cx, project.clone()).await;
3492
3493 test_summarize_error(&model, &thread, cx);
3494
3495 // Sending another message should not trigger another summarize request
3496 thread.update(cx, |thread, cx| {
3497 thread.insert_user_message(
3498 "How are you?",
3499 ContextLoadResult::default(),
3500 None,
3501 vec![],
3502 cx,
3503 );
3504 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3505 });
3506
3507 let fake_model = model.as_fake();
3508 simulate_successful_response(&fake_model, cx);
3509
3510 thread.read_with(cx, |thread, _| {
3511 // State is still Error, not Generating
3512 assert!(matches!(thread.summary(), ThreadSummary::Error));
3513 });
3514
3515 // But the summarize request can be invoked manually
3516 thread.update(cx, |thread, cx| {
3517 thread.summarize(cx);
3518 });
3519
3520 thread.read_with(cx, |thread, _| {
3521 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3522 });
3523
3524 cx.run_until_parked();
3525 fake_model.stream_last_completion_response("A successful summary");
3526 fake_model.end_last_completion_stream();
3527 cx.run_until_parked();
3528
3529 thread.read_with(cx, |thread, _| {
3530 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3531 assert_eq!(thread.summary().or_default(), "A successful summary");
3532 });
3533 }
3534
3535 fn test_summarize_error(
3536 model: &Arc<dyn LanguageModel>,
3537 thread: &Entity<Thread>,
3538 cx: &mut TestAppContext,
3539 ) {
3540 thread.update(cx, |thread, cx| {
3541 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3542 thread.send_to_model(
3543 model.clone(),
3544 CompletionIntent::ThreadSummarization,
3545 None,
3546 cx,
3547 );
3548 });
3549
3550 let fake_model = model.as_fake();
3551 simulate_successful_response(&fake_model, cx);
3552
3553 thread.read_with(cx, |thread, _| {
3554 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3555 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3556 });
3557
3558 // Simulate summary request ending
3559 cx.run_until_parked();
3560 fake_model.end_last_completion_stream();
3561 cx.run_until_parked();
3562
3563 // State is set to Error and default message
3564 thread.read_with(cx, |thread, _| {
3565 assert!(matches!(thread.summary(), ThreadSummary::Error));
3566 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3567 });
3568 }
3569
    // Completes the pending fake-model request with a canned assistant
    // response and closes its stream, parking the executor before and after
    // so all queued effects settle.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3576
    // Registers the global settings and registries the agent tests depend on.
    // Must run before any other setup in each test.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be installed first: every subsequent
            // `init`/`register` call reads from it.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3593
3594 // Helper to create a test project with test files
3595 async fn create_test_project(
3596 cx: &mut TestAppContext,
3597 files: serde_json::Value,
3598 ) -> Entity<Project> {
3599 let fs = FakeFs::new(cx.executor());
3600 fs.insert_tree(path!("/test"), files).await;
3601 Project::test(fs, [path!("/test").as_ref()], cx).await
3602 }
3603
3604 async fn setup_test_environment(
3605 cx: &mut TestAppContext,
3606 project: Entity<Project>,
3607 ) -> (
3608 Entity<Workspace>,
3609 Entity<ThreadStore>,
3610 Entity<Thread>,
3611 Entity<ContextStore>,
3612 Arc<dyn LanguageModel>,
3613 ) {
3614 let (workspace, cx) =
3615 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3616
3617 let thread_store = cx
3618 .update(|_, cx| {
3619 ThreadStore::load(
3620 project.clone(),
3621 cx.new(|_| ToolWorkingSet::default()),
3622 None,
3623 Arc::new(PromptBuilder::new(None).unwrap()),
3624 cx,
3625 )
3626 })
3627 .await
3628 .unwrap();
3629
3630 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3631 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3632
3633 let provider = Arc::new(FakeLanguageModelProvider);
3634 let model = provider.test_model();
3635 let model: Arc<dyn LanguageModel> = Arc::new(model);
3636
3637 cx.update(|_, cx| {
3638 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3639 registry.set_default_model(
3640 Some(ConfiguredModel {
3641 provider: provider.clone(),
3642 model: model.clone(),
3643 }),
3644 cx,
3645 );
3646 registry.set_thread_summary_model(
3647 Some(ConfiguredModel {
3648 provider,
3649 model: model.clone(),
3650 }),
3651 cx,
3652 );
3653 })
3654 });
3655
3656 (workspace, thread_store, thread, context_store, model)
3657 }
3658
3659 async fn add_file_to_context(
3660 project: &Entity<Project>,
3661 context_store: &Entity<ContextStore>,
3662 path: &str,
3663 cx: &mut TestAppContext,
3664 ) -> Result<Entity<language::Buffer>> {
3665 let buffer_path = project
3666 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3667 .unwrap();
3668
3669 let buffer = project
3670 .update(cx, |project, cx| {
3671 project.open_buffer(buffer_path.clone(), cx)
3672 })
3673 .await
3674 .unwrap();
3675
3676 context_store.update(cx, |context_store, cx| {
3677 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3678 });
3679
3680 Ok(buffer)
3681 }
3682}