1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Stable, globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
impl ThreadId {
    /// Creates a new unique thread ID from a random UUID v4.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
impl From<&str> for ThreadId {
    /// Wraps an existing ID string (e.g. one read back from storage).
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
79
impl PromptId {
    /// Creates a new unique prompt ID from a random UUID v4.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Monotonically increasing, per-thread identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
94
impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context (files, threads, etc.) that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// A single piece of message content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought text, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-encrypted thinking payload; never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful content worth displaying.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayed. (The previous implementation
    /// returned `text.is_empty()`, inverting the documented contract.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Point-in-time capture of the project's worktrees and unsaved buffers,
/// recorded for telemetry/feedback alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
202
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}
208
/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, when one could be computed.
    pub diff: Option<String>,
}
216
/// Associates a git checkpoint with the message that triggered it, so the
/// project can be rolled back to the state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// User thumbs-up/thumbs-down feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// Lifecycle of the long-form thread summary: not yet requested, being
/// generated, or complete with text.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        // Last message included in the in-flight summary request.
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        // Last message covered by the generated summary.
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token totals for a thread measured against the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via
        // ZED_THREAD_WARNING_THRESHOLD. A missing or unparsable value now
        // falls back to the default instead of panicking (previously
        // `.parse().unwrap()` would crash on any malformed value).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the running total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread is to its model's context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Server-side status of an in-flight completion request.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short display title; starts out pending until generated.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Ordered by ascending `MessageId` (relied upon by `Self::message`'s binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken when the current user message was inserted; finalized
    /// or discarded once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage, parallel to completion requests made so far.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; used for staleness checks.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Number of agentic turns still allowed; `u32::MAX` means unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// Lifecycle of the short thread title: not yet requested, in flight,
/// ready with text, or failed.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Records that a request overflowed the model's context window, so the UI
/// can explain why generation stopped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates a fresh, empty thread with the globally configured default
    /// model, current completion-mode setting, and a background task that
    /// captures an initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Snapshot is taken asynchronously so thread creation stays fast.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model selection is re-resolved against the registry
    /// (falling back to the global default), and restored creases do not
    /// reload their attached contexts (`context: None`).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Message IDs ascend, so the next ID follows the last stored one.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; fall back to
        // the registry's default model when it can't be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // structured contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Installs a callback invoked with each completion request and the
    /// events received for it (used for capturing/inspecting requests).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
581
    /// This thread's stable unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
585
    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
589
    /// Timestamp of the last modification to this thread.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
593
    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
597
    /// Starts a new prompt ID, marking the beginning of a new user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
601
    /// The shared project context used to build this thread's system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
605
    /// Returns the configured model, lazily falling back to the registry's
    /// default model when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }
612
    /// The model currently selected for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
616
    /// Sets (or clears) the thread's selected model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
621
    /// Current state of the thread's short title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
625
    /// Replaces the thread's title with `new_summary`, unless a summary is
    /// still pending or being generated (renames mid-generation are ignored).
    /// Emits `SummaryChanged` only when the title actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // After an error, compare against the default title.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        // An empty name falls back to the default title.
        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
644
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
648
    /// Sets the completion mode used for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
    /// Looks up a message by ID via binary search (messages are kept in
    /// ascending ID order).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }
661
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
665
    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
669
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before generation is considered stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): this only consults `last_received_chunk_at`; the
        // `None`-when-not-generating contract above holds only if that
        // timestamp is cleared when a completion finishes — confirm.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
678
    /// Records that a streamed chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
682
    /// Queue status of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
688
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
692
    /// Finds a pending tool use by its tool-use ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }
699
    /// Iterates over pending tool uses waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
706
    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
710
    /// The git checkpoint associated with a message, if one was recorded.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`, truncating the thread back to
    /// the checkpoint's message on success. Failures are surfaced through
    /// `last_restore_checkpoint` rather than dropped.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around for display; the thread is untouched.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // On success, roll the conversation back to the checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Once generation has finished, decides whether to keep the pending
    /// checkpoint: it is inserted only if the repository state changed since
    /// it was taken (or if the comparison itself fails, erring on the side
    /// of keeping it).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; finalize later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when something actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
796
    /// State of the most recent checkpoint-restore operation, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
800
    /// Removes `message_id` and every message after it, along with the
    /// checkpoints recorded for the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
814
    /// Iterates over the structured contexts attached to a message (empty
    /// when the message is not found).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }
852
    /// Returns whether the server reported that the tool-use limit was hit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
856
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
866
    /// The tool uses requested by a given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
870
    /// The tool results produced for a given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
877
    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
881
    /// The textual output of a completed tool use; `None` for image results
    /// or when the tool has no result yet.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }
894
    /// The UI card associated with a tool result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
898
899 /// Return tools that are both enabled and supported by the model
900 pub fn available_tools(
901 &self,
902 cx: &App,
903 model: Arc<dyn LanguageModel>,
904 ) -> Vec<LanguageModelRequestTool> {
905 if model.supports_tools() {
906 self.tools()
907 .read(cx)
908 .enabled_tools(cx)
909 .into_iter()
910 .filter_map(|tool| {
911 // Skip tools that cannot be supported
912 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
913 Some(LanguageModelRequestTool {
914 name: tool.name(),
915 description: tool.description(),
916 input_schema,
917 })
918 })
919 .collect()
920 } else {
921 Vec::default()
922 }
923 }
924
    /// Appends a user message, registering any referenced buffers with the
    /// action log and recording a pending git checkpoint when provided.
    /// Returns the new message's ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            // Mark referenced buffers as read so edits to them are tracked.
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The checkpoint stays pending until generation finishes; see
        // `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
960
    /// Appends an assistant message with the given segments and no attached
    /// context. Returns the new message's ID.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
974
    /// Appends a message with the next sequential ID, bumps `updated_at`,
    /// and emits `MessageAdded`. Returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
995
    /// Rewrites an existing message's role and segments in place, optionally
    /// replacing its context and recording a new checkpoint. Returns `false`
    /// when the message does not exist.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1026
    /// Removes a single message by ID, returning whether it was found.
    ///
    /// NOTE(review): unlike `truncate`, this does not remove the message's
    /// entry from `checkpoints_by_message` — confirm whether an associated
    /// checkpoint should be dropped here as well.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1036
    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Agent:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    // Redacted thinking is intentionally omitted.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }
1065
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot before reading the thread
    /// state, so the serialized form always includes it when available.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored per assistant message.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1148
    /// Returns how many model turns this thread may still consume before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1152
    /// Sets the number of model turns this thread may still consume.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1156
1157 pub fn send_to_model(
1158 &mut self,
1159 model: Arc<dyn LanguageModel>,
1160 intent: CompletionIntent,
1161 window: Option<AnyWindowHandle>,
1162 cx: &mut Context<Self>,
1163 ) {
1164 if self.remaining_turns == 0 {
1165 return;
1166 }
1167
1168 self.remaining_turns -= 1;
1169
1170 let request = self.to_completion_request(model.clone(), intent, cx);
1171
1172 self.stream_completion(request, model, window, cx);
1173 }
1174
1175 pub fn used_tools_since_last_user_message(&self) -> bool {
1176 for message in self.messages.iter().rev() {
1177 if self.tool_use.message_has_tool_results(message.id) {
1178 return true;
1179 } else if message.role == Role::User {
1180 return false;
1181 }
1182 }
1183
1184 false
1185 }
1186
    /// Builds the complete [`LanguageModelRequest`] for this thread: system
    /// prompt, conversation history (including tool uses and results),
    /// stale-file notices, available tools, and completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        // The system prompt is told which tools are available.
        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. Failures are
        // surfaced to the user but do not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching — a
        // message whose tool calls (if any) have all completed.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty text/thinking segments are dropped entirely.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool calls become a ToolUse on this message plus a
            // synthetic user message carrying the corresponding results.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A still-pending tool call makes this message ineligible
                    // as a cache breakpoint.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // Mark the last cacheable message as the prompt-cache breakpoint.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Append a trailing notice about files that changed since last read.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only request the thread's completion mode when the model supports
        // max mode; otherwise fall back to Normal.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1346
1347 fn to_summarize_request(
1348 &self,
1349 model: &Arc<dyn LanguageModel>,
1350 intent: CompletionIntent,
1351 added_user_message: String,
1352 cx: &App,
1353 ) -> LanguageModelRequest {
1354 let mut request = LanguageModelRequest {
1355 thread_id: None,
1356 prompt_id: None,
1357 intent: Some(intent),
1358 mode: None,
1359 messages: vec![],
1360 tools: Vec::new(),
1361 tool_choice: None,
1362 stop: Vec::new(),
1363 temperature: AssistantSettings::temperature_for_model(model, cx),
1364 };
1365
1366 for message in &self.messages {
1367 let mut request_message = LanguageModelRequestMessage {
1368 role: message.role,
1369 content: Vec::new(),
1370 cache: false,
1371 };
1372
1373 for segment in &message.segments {
1374 match segment {
1375 MessageSegment::Text(text) => request_message
1376 .content
1377 .push(MessageContent::Text(text.clone())),
1378 MessageSegment::Thinking { .. } => {}
1379 MessageSegment::RedactedThinking(_) => {}
1380 }
1381 }
1382
1383 if request_message.content.is_empty() {
1384 continue;
1385 }
1386
1387 request.messages.push(request_message);
1388 }
1389
1390 request.messages.push(LanguageModelRequestMessage {
1391 role: Role::User,
1392 content: vec![MessageContent::Text(added_user_message)],
1393 cache: false,
1394 });
1395
1396 request
1397 }
1398
1399 fn attached_tracked_files_state(
1400 &self,
1401 messages: &mut Vec<LanguageModelRequestMessage>,
1402 cx: &App,
1403 ) {
1404 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1405
1406 let mut stale_message = String::new();
1407
1408 let action_log = self.action_log.read(cx);
1409
1410 for stale_file in action_log.stale_buffers(cx) {
1411 let Some(file) = stale_file.read(cx).file() else {
1412 continue;
1413 };
1414
1415 if stale_message.is_empty() {
1416 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1417 }
1418
1419 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1420 }
1421
1422 let mut content = Vec::with_capacity(2);
1423
1424 if !stale_message.is_empty() {
1425 content.push(stale_message.into());
1426 }
1427
1428 if !content.is_empty() {
1429 let context_message = LanguageModelRequestMessage {
1430 role: Role::User,
1431 content,
1432 cache: false,
1433 };
1434
1435 messages.push(context_message);
1436 }
1437 }
1438
    /// Spawns a background task that streams a completion for `request` from
    /// `model` and applies each streamed event to this thread as it arrives.
    ///
    /// Text/thinking chunks are appended to the trailing assistant message
    /// (inserting a new one when needed), tool uses are registered with
    /// `self.tool_use`, and token-usage / queue-status updates are recorded.
    /// When the stream ends, the stop reason drives follow-up work (running
    /// pending tools, clearing the agent location, or unwinding a refused
    /// turn), errors are surfaced via [`ThreadEvent::ShowError`], and
    /// completion telemetry is reported. The task is kept alive through
    /// `self.pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture request/response pairs when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message this stream writes into, once one has
                // been created.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool-input JSON is recoverable:
                                // report it as a failed tool call and keep
                                // consuming the stream.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates appear to be cumulative per
                                // request, so only the delta since the last
                                // update is added to the running total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses need an assistant message to hang
                                // off of; create one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only emit a streamed-input event while the
                                // input JSON is still incomplete.
                                // NOTE(review): `(&tool_use.input).clone()` is
                                // equivalent to `tool_use.input.clone()`.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks can make progress between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            // Stop at the user message that
                                            // began this turn — the one whose
                                            // predecessor is an assistant
                                            // message.
                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map well-known error types to dedicated UI
                            // treatments; fall back to a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Report only the usage attributable to this request.
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1818
1819 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1820 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1821 println!("No thread summary model");
1822 return;
1823 };
1824
1825 if !model.provider.is_authenticated(cx) {
1826 return;
1827 }
1828
1829 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1830 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1831 If the conversation is about a specific subject, include it in the title. \
1832 Be descriptive. DO NOT speak in the first person.";
1833
1834 let request = self.to_summarize_request(
1835 &model.model,
1836 CompletionIntent::ThreadSummarization,
1837 added_user_message.into(),
1838 cx,
1839 );
1840
1841 self.summary = ThreadSummary::Generating;
1842
1843 self.pending_summary = cx.spawn(async move |this, cx| {
1844 let result = async {
1845 let mut messages = model.model.stream_completion(request, &cx).await?;
1846
1847 let mut new_summary = String::new();
1848 while let Some(event) = messages.next().await {
1849 let Ok(event) = event else {
1850 continue;
1851 };
1852 let text = match event {
1853 LanguageModelCompletionEvent::Text(text) => text,
1854 LanguageModelCompletionEvent::StatusUpdate(
1855 CompletionRequestStatus::UsageUpdated { amount, limit },
1856 ) => {
1857 this.update(cx, |thread, _cx| {
1858 thread.last_usage = Some(RequestUsage {
1859 limit,
1860 amount: amount as i32,
1861 });
1862 })?;
1863 continue;
1864 }
1865 _ => continue,
1866 };
1867
1868 let mut lines = text.lines();
1869 new_summary.extend(lines.next());
1870
1871 // Stop if the LLM generated multiple lines.
1872 if lines.next().is_some() {
1873 break;
1874 }
1875 }
1876
1877 anyhow::Ok(new_summary)
1878 }
1879 .await;
1880
1881 this.update(cx, |this, cx| {
1882 match result {
1883 Ok(new_summary) => {
1884 if new_summary.is_empty() {
1885 this.summary = ThreadSummary::Error;
1886 } else {
1887 this.summary = ThreadSummary::Ready(new_summary.into());
1888 }
1889 }
1890 Err(err) => {
1891 this.summary = ThreadSummary::Error;
1892 log::error!("Failed to generate thread summary: {}", err);
1893 }
1894 }
1895 cx.emit(ThreadEvent::SummaryGenerated);
1896 })
1897 .log_err()?;
1898
1899 Some(())
1900 });
1901 }
1902
    /// Starts generating a detailed Markdown summary of the thread, unless a
    /// summary keyed to the latest message is already generating or generated.
    ///
    /// Progress is published through `detailed_summary_tx`; after a successful
    /// generation the thread is saved via `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Record which message this summary corresponds to, so staleness can
        // be detected above.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Starting the stream failed; reset state so a retry can run.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed text; chunk errors are logged and
            // skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1997
    /// Waits until a detailed summary is available and returns it, falling
    /// back to the thread's full text when no generation is in progress.
    ///
    /// Returns `None` if the thread entity or the summary channel is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2015
2016 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2017 self.detailed_summary_rx
2018 .borrow()
2019 .text()
2020 .unwrap_or_else(|| self.text().into())
2021 }
2022
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2029
    /// Runs (or queues for user confirmation) every idle pending tool use.
    ///
    /// Tools that require confirmation — unless tool actions are globally
    /// auto-allowed in settings — emit
    /// [`ThreadEvent::ToolConfirmationNeeded`] instead of running
    /// immediately. Unknown tool names are reported back to the model as
    /// errors. Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full completion request as context.
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        // Only idle tool uses are picked up; running/confirmed ones are left
        // alone.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model referenced a tool that doesn't exist or isn't
                // enabled; report it back so the model can recover.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2084
2085 pub fn handle_hallucinated_tool_use(
2086 &mut self,
2087 tool_use_id: LanguageModelToolUseId,
2088 hallucinated_tool_name: Arc<str>,
2089 window: Option<AnyWindowHandle>,
2090 cx: &mut Context<Thread>,
2091 ) {
2092 let available_tools = self.tools.read(cx).enabled_tools(cx);
2093
2094 let tool_list = available_tools
2095 .iter()
2096 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2097 .collect::<Vec<_>>()
2098 .join("\n");
2099
2100 let error_message = format!(
2101 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2102 hallucinated_tool_name, tool_list
2103 );
2104
2105 let pending_tool_use = self.tool_use.insert_tool_output(
2106 tool_use_id.clone(),
2107 hallucinated_tool_name,
2108 Err(anyhow!("Missing tool call: {error_message}")),
2109 self.configured_model.as_ref(),
2110 );
2111
2112 cx.emit(ThreadEvent::MissingToolUse {
2113 tool_use_id: tool_use_id.clone(),
2114 ui_text: error_message.into(),
2115 });
2116
2117 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2118 }
2119
2120 pub fn receive_invalid_tool_json(
2121 &mut self,
2122 tool_use_id: LanguageModelToolUseId,
2123 tool_name: Arc<str>,
2124 invalid_json: Arc<str>,
2125 error: String,
2126 window: Option<AnyWindowHandle>,
2127 cx: &mut Context<Thread>,
2128 ) {
2129 log::error!("The model returned invalid input JSON: {invalid_json}");
2130
2131 let pending_tool_use = self.tool_use.insert_tool_output(
2132 tool_use_id.clone(),
2133 tool_name,
2134 Err(anyhow!("Error parsing input JSON: {error}")),
2135 self.configured_model.as_ref(),
2136 );
2137 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2138 pending_tool_use.ui_text.clone()
2139 } else {
2140 log::error!(
2141 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2142 );
2143 format!("Unknown tool {}", tool_use_id).into()
2144 };
2145
2146 cx.emit(ThreadEvent::InvalidToolInput {
2147 tool_use_id: tool_use_id.clone(),
2148 ui_text,
2149 invalid_input_json: invalid_json,
2150 });
2151
2152 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2153 }
2154
2155 pub fn run_tool(
2156 &mut self,
2157 tool_use_id: LanguageModelToolUseId,
2158 ui_text: impl Into<SharedString>,
2159 input: serde_json::Value,
2160 request: Arc<LanguageModelRequest>,
2161 tool: Arc<dyn Tool>,
2162 model: Arc<dyn LanguageModel>,
2163 window: Option<AnyWindowHandle>,
2164 cx: &mut Context<Thread>,
2165 ) {
2166 let task =
2167 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2168 self.tool_use
2169 .run_pending_tool(tool_use_id, ui_text.into(), task);
2170 }
2171
    /// Spawns the actual execution of a tool use and returns the task that
    /// records its output when finished.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Refuse to run tools that have been disabled since they were
        // requested; produce an error result instead.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        // Record the output, then notify listeners and
                        // possibly resume the model turn via tool_finished.
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2222
2223 fn tool_finished(
2224 &mut self,
2225 tool_use_id: LanguageModelToolUseId,
2226 pending_tool_use: Option<PendingToolUse>,
2227 canceled: bool,
2228 window: Option<AnyWindowHandle>,
2229 cx: &mut Context<Self>,
2230 ) {
2231 if self.all_tools_finished() {
2232 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2233 if !canceled {
2234 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2235 }
2236 self.auto_capture_telemetry(cx);
2237 }
2238 }
2239
2240 cx.emit(ThreadEvent::ToolFinished {
2241 tool_use_id,
2242 pending_tool_use,
2243 });
2244 }
2245
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also cancels all pending tool uses and finalizes any pending
    /// checkpoint. Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Popping drops the completion's task, which presumably aborts the
        // in-flight request — TODO confirm gpui Task-drop semantics.
        let mut canceled = self.pending_completions.pop().is_some();

        // Run the normal post-tool bookkeeping for each canceled tool use with
        // `canceled = true`, so no follow-up completion is started.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        self.finalize_pending_checkpoint(cx);

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);
        }

        canceled
    }
2275
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This only emits [`ThreadEvent::CancelEditing`]; listeners (like
    /// ActiveThread) are responsible for actually canceling their editing
    /// operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2283
    /// Returns the thread-level feedback, if any has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2287
    /// Returns the feedback recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2291
    /// Records the user's feedback for a single message and reports it to
    /// telemetry together with the message content, the serialized thread,
    /// and a snapshot of the project.
    ///
    /// Submitting the same feedback for the same message twice is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Ignore duplicate feedback so repeated clicks don't re-report.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the telemetry event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2349
    /// Records thread-level feedback.
    ///
    /// If the thread contains an assistant message, the feedback is attached
    /// to the most recent one via [`Self::report_message_feedback`];
    /// otherwise a thread-level telemetry event is reported directly (without
    /// message-specific fields).
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Prefer attaching the feedback to the latest assistant message.
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather
                // than dropping the telemetry event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2395
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; the paths of all dirty
    /// buffers are collected afterwards. Failures to read app state are
    /// swallowed, leaving the unsaved-buffer list empty.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off a snapshot task for every visible worktree.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record the path of every open buffer with unsaved changes.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2433
    /// Captures a [`WorktreeSnapshot`] (absolute path plus optional git
    /// state) for a single worktree.
    ///
    /// Every failure along the way — reading app state, finding the
    /// repository, or running the git job — degrades to missing fields
    /// rather than failing the task.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree's root, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a git backend;
                            // otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<job>> produced above; any
            // failure yields `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2511
2512 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2513 let mut markdown = Vec::new();
2514
2515 let summary = self.summary().or_default();
2516 writeln!(markdown, "# {summary}\n")?;
2517
2518 for message in self.messages() {
2519 writeln!(
2520 markdown,
2521 "## {role}\n",
2522 role = match message.role {
2523 Role::User => "User",
2524 Role::Assistant => "Agent",
2525 Role::System => "System",
2526 }
2527 )?;
2528
2529 if !message.loaded_context.text.is_empty() {
2530 writeln!(markdown, "{}", message.loaded_context.text)?;
2531 }
2532
2533 if !message.loaded_context.images.is_empty() {
2534 writeln!(
2535 markdown,
2536 "\n{} images attached as context.\n",
2537 message.loaded_context.images.len()
2538 )?;
2539 }
2540
2541 for segment in &message.segments {
2542 match segment {
2543 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2544 MessageSegment::Thinking { text, .. } => {
2545 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2546 }
2547 MessageSegment::RedactedThinking(_) => {}
2548 }
2549 }
2550
2551 for tool_use in self.tool_uses_for_message(message.id, cx) {
2552 writeln!(
2553 markdown,
2554 "**Use Tool: {} ({})**",
2555 tool_use.name, tool_use.id
2556 )?;
2557 writeln!(markdown, "```json")?;
2558 writeln!(
2559 markdown,
2560 "{}",
2561 serde_json::to_string_pretty(&tool_use.input)?
2562 )?;
2563 writeln!(markdown, "```")?;
2564 }
2565
2566 for tool_result in self.tool_results_for_message(message.id) {
2567 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2568 if tool_result.is_error {
2569 write!(markdown, " (Error)")?;
2570 }
2571
2572 writeln!(markdown, "**\n")?;
2573 match &tool_result.content {
2574 LanguageModelToolResultContent::Text(text)
2575 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2576 text,
2577 ..
2578 }) => {
2579 writeln!(markdown, "{text}")?;
2580 }
2581 LanguageModelToolResultContent::Image(image) => {
2582 writeln!(markdown, "", image.source)?;
2583 }
2584 }
2585
2586 if let Some(output) = tool_result.output.as_ref() {
2587 writeln!(
2588 markdown,
2589 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2590 serde_json::to_string_pretty(output)?
2591 )?;
2592 }
2593 }
2594 }
2595
2596 Ok(String::from_utf8_lossy(&markdown).to_string())
2597 }
2598
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted), delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2609
    /// Marks every tracked agent edit as kept (accepted), delegating to the
    /// action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2614
    /// Rejects (reverts) the agent's edits within `buffer_ranges` of
    /// `buffer`, delegating to the action log. Returns the task performing
    /// the revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2625
    /// The log tracking the agent's buffer edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2629
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2633
2634 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2635 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2636 return;
2637 }
2638
2639 let now = Instant::now();
2640 if let Some(last) = self.last_auto_capture_at {
2641 if now.duration_since(last).as_secs() < 10 {
2642 return;
2643 }
2644 }
2645
2646 self.last_auto_capture_at = Some(now);
2647
2648 let thread_id = self.id().clone();
2649 let github_login = self
2650 .project
2651 .read(cx)
2652 .user_store()
2653 .read(cx)
2654 .current_user()
2655 .map(|user| user.github_login.clone());
2656 let client = self.project.read(cx).client();
2657 let serialize_task = self.serialize(cx);
2658
2659 cx.background_executor()
2660 .spawn(async move {
2661 if let Ok(serialized_thread) = serialize_task.await {
2662 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2663 telemetry::event!(
2664 "Agent Thread Auto-Captured",
2665 thread_id = thread_id.to_string(),
2666 thread_data = thread_data,
2667 auto_capture_reason = "tracked_user",
2668 github_login = github_login
2669 );
2670
2671 client.telemetry().flush_events().await;
2672 }
2673 }
2674 })
2675 .detach();
2676 }
2677
    /// The token usage accumulated across this thread's completions.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2681
2682 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2683 let Some(model) = self.configured_model.as_ref() else {
2684 return TotalTokenUsage::default();
2685 };
2686
2687 let max = model.model.max_token_count();
2688
2689 let index = self
2690 .messages
2691 .iter()
2692 .position(|msg| msg.id == message_id)
2693 .unwrap_or(0);
2694
2695 if index == 0 {
2696 return TotalTokenUsage { total: 0, max };
2697 }
2698
2699 let token_usage = &self
2700 .request_token_usage
2701 .get(index - 1)
2702 .cloned()
2703 .unwrap_or_default();
2704
2705 TotalTokenUsage {
2706 total: token_usage.total_tokens() as usize,
2707 max,
2708 }
2709 }
2710
2711 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2712 let model = self.configured_model.as_ref()?;
2713
2714 let max = model.model.max_token_count();
2715
2716 if let Some(exceeded_error) = &self.exceeded_window_error {
2717 if model.model.id() == exceeded_error.model_id {
2718 return Some(TotalTokenUsage {
2719 total: exceeded_error.token_count,
2720 max,
2721 });
2722 }
2723 }
2724
2725 let total = self
2726 .token_usage_at_last_message()
2727 .unwrap_or_default()
2728 .total_tokens() as usize;
2729
2730 Some(TotalTokenUsage { total, max })
2731 }
2732
    /// Returns the token usage recorded at the index of the last message,
    /// falling back to the most recent recorded usage when there are fewer
    /// usage entries than messages.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2739
    /// Overwrites the token usage entry for the last message, resizing the
    /// usage list so it lines up with the message list.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any newly-created entries with the previous last usage (or the
        // default when none exists) so intermediate entries stay plausible.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2749
2750 pub fn deny_tool_use(
2751 &mut self,
2752 tool_use_id: LanguageModelToolUseId,
2753 tool_name: Arc<str>,
2754 window: Option<AnyWindowHandle>,
2755 cx: &mut Context<Self>,
2756 ) {
2757 let err = Err(anyhow::anyhow!(
2758 "Permission to run tool action denied by user"
2759 ));
2760
2761 self.tool_use.insert_tool_output(
2762 tool_use_id.clone(),
2763 tool_name,
2764 err,
2765 self.configured_model.as_ref(),
2766 );
2767 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2768 }
2769}
2770
/// Errors carried by [`ThreadEvent::ShowError`] for display to the user.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before more requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The plan's model request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2783
/// Events emitted by a [`Thread`] over the course of a completion: streaming
/// output, tool lifecycle, message changes, and summary updates.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion event was streamed in.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new request to the model was started.
    NewRequest,
    /// Streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking (reasoning) text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use was streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use that is missing from the thread's bookkeeping.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// A message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// The thread summary finished generating.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Emitted by `tool_finished` each time a tool use completes (including
    /// cancellations and denials).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before running.
    ToolConfirmationNeeded,
    /// Emitted by `cancel_editing`; listeners should cancel in-progress edits.
    CancelEditing,
    /// Emitted by `cancel_last_completion` when something was canceled.
    CompletionCanceled,
}
2826
// `Thread` broadcasts `ThreadEvent`s to its observers.
impl EventEmitter<ThreadEvent> for Thread {}
2828
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Id distinguishing this completion from others on the thread.
    id: usize,
    queue_state: QueueState,
    // Held so the request keeps running; dropping it presumably cancels the
    // task (gpui Task semantics) — see `cancel_last_completion`.
    _task: Task<()>,
}
2834
2835#[cfg(test)]
2836mod tests {
2837 use super::*;
2838 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2839 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2840 use assistant_tool::ToolRegistry;
2841 use editor::EditorSettings;
2842 use gpui::TestAppContext;
2843 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2844 use project::{FakeFs, Project};
2845 use prompt_store::PromptBuilder;
2846 use serde_json::json;
2847 use settings::{Settings, SettingsStore};
2848 use std::sync::Arc;
2849 use theme::ThemeSettings;
2850 use util::path;
2851 use workspace::Workspace;
2852
    // Verifies that a file attached as context is rendered into the user
    // message's <context> block and carried into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // messages[1] is the user message; the context text precedes the
        // user's prompt within it.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2928
    // Verifies that each user message only embeds context that earlier
    // messages have not already included, and that `new_context_for_thread`
    // respects an "exclude messages after" cursor plus context removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cursor at message 2, contexts from messages after it count as
        // "new" again: expect contexts 2, 3, and the newly added 4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3084
    // Verifies that user messages with no attached context carry an empty
    // context string and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3162
    // Verifies that editing a buffer after it was attached as context causes
    // a "files changed since last read" notification to be appended to the
    // completion request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3249
    // Verifies how `model_parameters` settings select a temperature: matching
    // on provider+model, model alone, or provider alone applies the value,
    // while a provider mismatch leaves the temperature unset.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3343
    // Covers the summary lifecycle: a new thread starts Pending, generation
    // begins after the first exchange, and manual edits are only accepted once
    // generation has produced a Ready summary.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3428
3429 #[gpui::test]
3430 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3431 init_test_settings(cx);
3432
3433 let project = create_test_project(cx, json!({})).await;
3434
3435 let (_, _thread_store, thread, _context_store, model) =
3436 setup_test_environment(cx, project.clone()).await;
3437
3438 test_summarize_error(&model, &thread, cx);
3439
3440 // Now we should be able to set a summary
3441 thread.update(cx, |thread, cx| {
3442 thread.set_summary("Brief Intro", cx);
3443 });
3444
3445 thread.read_with(cx, |thread, _| {
3446 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3447 assert_eq!(thread.summary().or_default(), "Brief Intro");
3448 });
3449 }
3450
    // A failed summarization is not retried automatically on the next message;
    // only an explicit call to `summarize` starts a new attempt.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread's summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manual retry.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3501
    // Shared helper: drives `thread` into `ThreadSummary::Error` by sending a
    // message, letting the assistant reply successfully, and then ending the
    // summarization stream without producing any summary text.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        // (closing the stream with no text — verified below to yield Error)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3536
    // Completes the most recent pending completion with a single canned
    // "Assistant response" chunk, settling pending tasks before and after.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3543
    // Installs a test `SettingsStore` and registers every settings type and
    // global the thread tests touch (language, project, assistant, prompt
    // store, thread store, workspace, language models, theme, editor, and the
    // tool registry).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3560
3561 // Helper to create a test project with test files
3562 async fn create_test_project(
3563 cx: &mut TestAppContext,
3564 files: serde_json::Value,
3565 ) -> Entity<Project> {
3566 let fs = FakeFs::new(cx.executor());
3567 fs.insert_tree(path!("/test"), files).await;
3568 Project::test(fs, [path!("/test").as_ref()], cx).await
3569 }
3570
    // Builds the standard fixture for thread tests: a test workspace, a
    // `ThreadStore` (with an empty tool working set), a fresh thread, a
    // context store, and a `FakeLanguageModel` registered as both the default
    // and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model so both regular completions and summary
        // requests in these tests are routed to it.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3625
3626 async fn add_file_to_context(
3627 project: &Entity<Project>,
3628 context_store: &Entity<ContextStore>,
3629 path: &str,
3630 cx: &mut TestAppContext,
3631 ) -> Result<Entity<language::Buffer>> {
3632 let buffer_path = project
3633 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3634 .unwrap();
3635
3636 let buffer = project
3637 .update(cx, |project, cx| {
3638 project.open_buffer(buffer_path.clone(), cx)
3639 })
3640 .await
3641 .unwrap();
3642
3643 context_store.update(cx, |context_store, cx| {
3644 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3645 });
3646
3647 Ok(buffer)
3648 }
3649}