1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5
6use agent_rules::load_worktree_rules_file;
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use fs::Fs;
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
19 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
20 LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
21 PaymentRequiredError, Role, StopReason, TokenUsage,
22};
23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
24use project::{Project, Worktree};
25use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use settings::Settings;
29use util::{ResultExt as _, TryFutureExt as _, post_inc};
30use uuid::Uuid;
31
32use crate::context::{AssistantContext, ContextId, format_context_as_string};
33use crate::thread_store::{
34 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
35 SerializedToolUse,
36};
37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
38
/// The purpose a completion request is being built for.
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A regular chat request.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
45
46#[derive(
47 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
48)]
49pub struct ThreadId(Arc<str>);
50
51impl ThreadId {
52 pub fn new() -> Self {
53 Self(Uuid::new_v4().to_string().into())
54 }
55}
56
57impl std::fmt::Display for ThreadId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
63impl From<&str> for ThreadId {
64 fn from(value: &str) -> Self {
65 Self(value.into())
66 }
67}
68
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
70pub struct MessageId(pub(crate) usize);
71
72impl MessageId {
73 fn post_inc(&mut self) -> Self {
74 Self(post_inc(&mut self.0))
75 }
76}
77
78/// A message in a [`Thread`].
79#[derive(Debug, Clone)]
80pub struct Message {
81 pub id: MessageId,
82 pub role: Role,
83 pub segments: Vec<MessageSegment>,
84 pub context: String,
85}
86
87impl Message {
88 /// Returns whether the message contains any meaningful text that should be displayed
89 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
90 pub fn should_display_content(&self) -> bool {
91 self.segments.iter().all(|segment| segment.should_display())
92 }
93
94 pub fn push_thinking(&mut self, text: &str) {
95 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
96 segment.push_str(text);
97 } else {
98 self.segments
99 .push(MessageSegment::Thinking(text.to_string()));
100 }
101 }
102
103 pub fn push_text(&mut self, text: &str) {
104 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
105 segment.push_str(text);
106 } else {
107 self.segments.push(MessageSegment::Text(text.to_string()));
108 }
109 }
110
111 pub fn to_string(&self) -> String {
112 let mut result = String::new();
113
114 if !self.context.is_empty() {
115 result.push_str(&self.context);
116 }
117
118 for segment in &self.segments {
119 match segment {
120 MessageSegment::Text(text) => result.push_str(text),
121 MessageSegment::Thinking(text) => {
122 result.push_str("<think>");
123 result.push_str(text);
124 result.push_str("</think>");
125 }
126 }
127 }
128
129 result
130 }
131}
132
/// One piece of a [`Message`] body.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageSegment {
    /// Regular message text.
    Text(String),
    /// Thinking text; rendered between `<think>` tags when the message is turned into text.
    Thinking(String),
}
138
139impl MessageSegment {
140 pub fn text_mut(&mut self) -> &mut String {
141 match self {
142 Self::Text(text) => text,
143 Self::Thinking(text) => text,
144 }
145 }
146
147 pub fn should_display(&self) -> bool {
148 // We add USING_TOOL_MARKER when making a request that includes tool uses
149 // without non-whitespace text around them, and this can cause the model
150 // to mimic the pattern, so we consider those segments not displayable.
151 match self {
152 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
153 Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
154 }
155 }
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ProjectSnapshot {
160 pub worktree_snapshots: Vec<WorktreeSnapshot>,
161 pub unsaved_buffer_paths: Vec<String>,
162 pub timestamp: DateTime<Utc>,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct WorktreeSnapshot {
167 pub worktree_path: String,
168 pub git_state: Option<GitState>,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct GitState {
173 pub remote_url: Option<String>,
174 pub head_sha: Option<String>,
175 pub current_branch: Option<String>,
176 pub diff: Option<String>,
177}
178
179#[derive(Clone)]
180pub struct ThreadCheckpoint {
181 message_id: MessageId,
182 git_checkpoint: GitStoreCheckpoint,
183}
184
/// Feedback recorded for a thread.
#[derive(Clone, Copy, Debug)]
pub enum ThreadFeedback {
    /// Positive feedback.
    Positive,
    /// Negative feedback.
    Negative,
}
190
191pub enum LastRestoreCheckpoint {
192 Pending {
193 message_id: MessageId,
194 },
195 Error {
196 message_id: MessageId,
197 error: String,
198 },
199}
200
201impl LastRestoreCheckpoint {
202 pub fn message_id(&self) -> MessageId {
203 match self {
204 LastRestoreCheckpoint::Pending { message_id } => *message_id,
205 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
206 }
207 }
208}
209
210#[derive(Clone, Debug, Default, Serialize, Deserialize)]
211pub enum DetailedSummaryState {
212 #[default]
213 NotGenerated,
214 Generating {
215 message_id: MessageId,
216 },
217 Generated {
218 text: SharedString,
219 message_id: MessageId,
220 },
221}
222
223#[derive(Default)]
224pub struct TotalTokenUsage {
225 pub total: usize,
226 pub max: usize,
227 pub ratio: TokenUsageRatio,
228}
229
/// Coarse classification of token usage.
///
/// Derives `Debug`, `Clone` and `Copy` in addition to the comparison traits so values can be
/// logged, used in `assert_eq!`, and passed by value.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TokenUsageRatio {
    /// The default classification.
    #[default]
    Normal,
    /// Warning level.
    Warning,
    /// Exceeded level.
    Exceeded,
}
237
238/// A thread of conversation with the LLM.
239pub struct Thread {
240 id: ThreadId,
241 updated_at: DateTime<Utc>,
242 summary: Option<SharedString>,
243 pending_summary: Task<Option<()>>,
244 detailed_summary_state: DetailedSummaryState,
245 messages: Vec<Message>,
246 next_message_id: MessageId,
247 context: BTreeMap<ContextId, AssistantContext>,
248 context_by_message: HashMap<MessageId, Vec<ContextId>>,
249 system_prompt_context: Option<AssistantSystemPromptContext>,
250 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
251 completion_count: usize,
252 pending_completions: Vec<PendingCompletion>,
253 project: Entity<Project>,
254 prompt_builder: Arc<PromptBuilder>,
255 tools: Arc<ToolWorkingSet>,
256 tool_use: ToolUseState,
257 action_log: Entity<ActionLog>,
258 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
259 pending_checkpoint: Option<ThreadCheckpoint>,
260 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
261 cumulative_token_usage: TokenUsage,
262 feedback: Option<ThreadFeedback>,
263}
264
265impl Thread {
266 pub fn new(
267 project: Entity<Project>,
268 tools: Arc<ToolWorkingSet>,
269 prompt_builder: Arc<PromptBuilder>,
270 cx: &mut Context<Self>,
271 ) -> Self {
272 Self {
273 id: ThreadId::new(),
274 updated_at: Utc::now(),
275 summary: None,
276 pending_summary: Task::ready(None),
277 detailed_summary_state: DetailedSummaryState::NotGenerated,
278 messages: Vec::new(),
279 next_message_id: MessageId(0),
280 context: BTreeMap::default(),
281 context_by_message: HashMap::default(),
282 system_prompt_context: None,
283 checkpoints_by_message: HashMap::default(),
284 completion_count: 0,
285 pending_completions: Vec::new(),
286 project: project.clone(),
287 prompt_builder,
288 tools: tools.clone(),
289 last_restore_checkpoint: None,
290 pending_checkpoint: None,
291 tool_use: ToolUseState::new(tools.clone()),
292 action_log: cx.new(|_| ActionLog::new(project.clone())),
293 initial_project_snapshot: {
294 let project_snapshot = Self::project_snapshot(project, cx);
295 cx.foreground_executor()
296 .spawn(async move { Some(project_snapshot.await) })
297 .shared()
298 },
299 cumulative_token_usage: TokenUsage::default(),
300 feedback: None,
301 }
302 }
303
304 pub fn deserialize(
305 id: ThreadId,
306 serialized: SerializedThread,
307 project: Entity<Project>,
308 tools: Arc<ToolWorkingSet>,
309 prompt_builder: Arc<PromptBuilder>,
310 cx: &mut Context<Self>,
311 ) -> Self {
312 let next_message_id = MessageId(
313 serialized
314 .messages
315 .last()
316 .map(|message| message.id.0 + 1)
317 .unwrap_or(0),
318 );
319 let tool_use =
320 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
321
322 Self {
323 id,
324 updated_at: serialized.updated_at,
325 summary: Some(serialized.summary),
326 pending_summary: Task::ready(None),
327 detailed_summary_state: serialized.detailed_summary_state,
328 messages: serialized
329 .messages
330 .into_iter()
331 .map(|message| Message {
332 id: message.id,
333 role: message.role,
334 segments: message
335 .segments
336 .into_iter()
337 .map(|segment| match segment {
338 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
339 SerializedMessageSegment::Thinking { text } => {
340 MessageSegment::Thinking(text)
341 }
342 })
343 .collect(),
344 context: message.context,
345 })
346 .collect(),
347 next_message_id,
348 context: BTreeMap::default(),
349 context_by_message: HashMap::default(),
350 system_prompt_context: None,
351 checkpoints_by_message: HashMap::default(),
352 completion_count: 0,
353 pending_completions: Vec::new(),
354 last_restore_checkpoint: None,
355 pending_checkpoint: None,
356 project: project.clone(),
357 prompt_builder,
358 tools,
359 tool_use,
360 action_log: cx.new(|_| ActionLog::new(project)),
361 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
362 cumulative_token_usage: serialized.cumulative_token_usage,
363 feedback: None,
364 }
365 }
366
367 pub fn id(&self) -> &ThreadId {
368 &self.id
369 }
370
371 pub fn is_empty(&self) -> bool {
372 self.messages.is_empty()
373 }
374
375 pub fn updated_at(&self) -> DateTime<Utc> {
376 self.updated_at
377 }
378
379 pub fn touch_updated_at(&mut self) {
380 self.updated_at = Utc::now();
381 }
382
383 pub fn summary(&self) -> Option<SharedString> {
384 self.summary.clone()
385 }
386
387 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
388
389 pub fn summary_or_default(&self) -> SharedString {
390 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
391 }
392
393 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
394 let Some(current_summary) = &self.summary else {
395 // Don't allow setting summary until generated
396 return;
397 };
398
399 let mut new_summary = new_summary.into();
400
401 if new_summary.is_empty() {
402 new_summary = Self::DEFAULT_SUMMARY;
403 }
404
405 if current_summary != &new_summary {
406 self.summary = Some(new_summary);
407 cx.emit(ThreadEvent::SummaryChanged);
408 }
409 }
410
411 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
412 self.latest_detailed_summary()
413 .unwrap_or_else(|| self.text().into())
414 }
415
416 fn latest_detailed_summary(&self) -> Option<SharedString> {
417 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
418 Some(text.clone())
419 } else {
420 None
421 }
422 }
423
424 pub fn message(&self, id: MessageId) -> Option<&Message> {
425 self.messages.iter().find(|message| message.id == id)
426 }
427
428 pub fn messages(&self) -> impl Iterator<Item = &Message> {
429 self.messages.iter()
430 }
431
432 pub fn is_generating(&self) -> bool {
433 !self.pending_completions.is_empty() || !self.all_tools_finished()
434 }
435
436 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
437 &self.tools
438 }
439
440 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
441 self.tool_use
442 .pending_tool_uses()
443 .into_iter()
444 .find(|tool_use| &tool_use.id == id)
445 }
446
447 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
448 self.tool_use
449 .pending_tool_uses()
450 .into_iter()
451 .filter(|tool_use| tool_use.status.needs_confirmation())
452 }
453
454 pub fn has_pending_tool_uses(&self) -> bool {
455 !self.tool_use.pending_tool_uses().is_empty()
456 }
457
458 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
459 self.checkpoints_by_message.get(&id).cloned()
460 }
461
462 pub fn restore_checkpoint(
463 &mut self,
464 checkpoint: ThreadCheckpoint,
465 cx: &mut Context<Self>,
466 ) -> Task<Result<()>> {
467 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
468 message_id: checkpoint.message_id,
469 });
470 cx.emit(ThreadEvent::CheckpointChanged);
471 cx.notify();
472
473 let git_store = self.project().read(cx).git_store().clone();
474 let restore = git_store.update(cx, |git_store, cx| {
475 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
476 });
477
478 cx.spawn(async move |this, cx| {
479 let result = restore.await;
480 this.update(cx, |this, cx| {
481 if let Err(err) = result.as_ref() {
482 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
483 message_id: checkpoint.message_id,
484 error: err.to_string(),
485 });
486 } else {
487 this.truncate(checkpoint.message_id, cx);
488 this.last_restore_checkpoint = None;
489 }
490 this.pending_checkpoint = None;
491 cx.emit(ThreadEvent::CheckpointChanged);
492 cx.notify();
493 })?;
494 result
495 })
496 }
497
498 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
499 let pending_checkpoint = if self.is_generating() {
500 return;
501 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
502 checkpoint
503 } else {
504 return;
505 };
506
507 let git_store = self.project.read(cx).git_store().clone();
508 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
509 cx.spawn(async move |this, cx| match final_checkpoint.await {
510 Ok(final_checkpoint) => {
511 let equal = git_store
512 .update(cx, |store, cx| {
513 store.compare_checkpoints(
514 pending_checkpoint.git_checkpoint.clone(),
515 final_checkpoint.clone(),
516 cx,
517 )
518 })?
519 .await
520 .unwrap_or(false);
521
522 if equal {
523 git_store
524 .update(cx, |store, cx| {
525 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
526 })?
527 .detach();
528 } else {
529 this.update(cx, |this, cx| {
530 this.insert_checkpoint(pending_checkpoint, cx)
531 })?;
532 }
533
534 git_store
535 .update(cx, |store, cx| {
536 store.delete_checkpoint(final_checkpoint, cx)
537 })?
538 .detach();
539
540 Ok(())
541 }
542 Err(_) => this.update(cx, |this, cx| {
543 this.insert_checkpoint(pending_checkpoint, cx)
544 }),
545 })
546 .detach();
547 }
548
549 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
550 self.checkpoints_by_message
551 .insert(checkpoint.message_id, checkpoint);
552 cx.emit(ThreadEvent::CheckpointChanged);
553 cx.notify();
554 }
555
556 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
557 self.last_restore_checkpoint.as_ref()
558 }
559
560 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
561 let Some(message_ix) = self
562 .messages
563 .iter()
564 .rposition(|message| message.id == message_id)
565 else {
566 return;
567 };
568 for deleted_message in self.messages.drain(message_ix..) {
569 self.context_by_message.remove(&deleted_message.id);
570 self.checkpoints_by_message.remove(&deleted_message.id);
571 }
572 cx.notify();
573 }
574
575 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
576 self.context_by_message
577 .get(&id)
578 .into_iter()
579 .flat_map(|context| {
580 context
581 .iter()
582 .filter_map(|context_id| self.context.get(&context_id))
583 })
584 }
585
586 /// Returns whether all of the tool uses have finished running.
587 pub fn all_tools_finished(&self) -> bool {
588 // If the only pending tool uses left are the ones with errors, then
589 // that means that we've finished running all of the pending tools.
590 self.tool_use
591 .pending_tool_uses()
592 .iter()
593 .all(|tool_use| tool_use.status.is_error())
594 }
595
596 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
597 self.tool_use.tool_uses_for_message(id, cx)
598 }
599
600 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
601 self.tool_use.tool_results_for_message(id)
602 }
603
604 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
605 self.tool_use.tool_result(id)
606 }
607
608 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
609 self.tool_use.message_has_tool_results(message_id)
610 }
611
612 pub fn insert_user_message(
613 &mut self,
614 text: impl Into<String>,
615 context: Vec<AssistantContext>,
616 git_checkpoint: Option<GitStoreCheckpoint>,
617 cx: &mut Context<Self>,
618 ) -> MessageId {
619 let text = text.into();
620
621 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
622
623 // Filter out contexts that have already been included in previous messages
624 let new_context: Vec<_> = context
625 .into_iter()
626 .filter(|ctx| !self.context.contains_key(&ctx.id()))
627 .collect();
628
629 if !new_context.is_empty() {
630 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
631 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
632 message.context = context_string;
633 }
634 }
635
636 self.action_log.update(cx, |log, cx| {
637 // Track all buffers added as context
638 for ctx in &new_context {
639 match ctx {
640 AssistantContext::File(file_ctx) => {
641 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
642 }
643 AssistantContext::Directory(dir_ctx) => {
644 for context_buffer in &dir_ctx.context_buffers {
645 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
646 }
647 }
648 AssistantContext::Symbol(symbol_ctx) => {
649 log.buffer_added_as_context(
650 symbol_ctx.context_symbol.buffer.clone(),
651 cx,
652 );
653 }
654 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
655 }
656 }
657 });
658 }
659
660 let context_ids = new_context
661 .iter()
662 .map(|context| context.id())
663 .collect::<Vec<_>>();
664 self.context.extend(
665 new_context
666 .into_iter()
667 .map(|context| (context.id(), context)),
668 );
669 self.context_by_message.insert(message_id, context_ids);
670
671 if let Some(git_checkpoint) = git_checkpoint {
672 self.pending_checkpoint = Some(ThreadCheckpoint {
673 message_id,
674 git_checkpoint,
675 });
676 }
677 message_id
678 }
679
680 pub fn insert_message(
681 &mut self,
682 role: Role,
683 segments: Vec<MessageSegment>,
684 cx: &mut Context<Self>,
685 ) -> MessageId {
686 let id = self.next_message_id.post_inc();
687 self.messages.push(Message {
688 id,
689 role,
690 segments,
691 context: String::new(),
692 });
693 self.touch_updated_at();
694 cx.emit(ThreadEvent::MessageAdded(id));
695 id
696 }
697
698 pub fn edit_message(
699 &mut self,
700 id: MessageId,
701 new_role: Role,
702 new_segments: Vec<MessageSegment>,
703 cx: &mut Context<Self>,
704 ) -> bool {
705 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
706 return false;
707 };
708 message.role = new_role;
709 message.segments = new_segments;
710 self.touch_updated_at();
711 cx.emit(ThreadEvent::MessageEdited(id));
712 true
713 }
714
715 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
716 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
717 return false;
718 };
719 self.messages.remove(index);
720 self.context_by_message.remove(&id);
721 self.touch_updated_at();
722 cx.emit(ThreadEvent::MessageDeleted(id));
723 true
724 }
725
726 /// Returns the representation of this [`Thread`] in a textual form.
727 ///
728 /// This is the representation we use when attaching a thread as context to another thread.
729 pub fn text(&self) -> String {
730 let mut text = String::new();
731
732 for message in &self.messages {
733 text.push_str(match message.role {
734 language_model::Role::User => "User:",
735 language_model::Role::Assistant => "Assistant:",
736 language_model::Role::System => "System:",
737 });
738 text.push('\n');
739
740 for segment in &message.segments {
741 match segment {
742 MessageSegment::Text(content) => text.push_str(content),
743 MessageSegment::Thinking(content) => {
744 text.push_str(&format!("<think>{}</think>", content))
745 }
746 }
747 }
748 text.push('\n');
749 }
750
751 text
752 }
753
754 /// Serializes this thread into a format for storage or telemetry.
755 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
756 let initial_project_snapshot = self.initial_project_snapshot.clone();
757 cx.spawn(async move |this, cx| {
758 let initial_project_snapshot = initial_project_snapshot.await;
759 this.read_with(cx, |this, cx| SerializedThread {
760 version: SerializedThread::VERSION.to_string(),
761 summary: this.summary_or_default(),
762 updated_at: this.updated_at(),
763 messages: this
764 .messages()
765 .map(|message| SerializedMessage {
766 id: message.id,
767 role: message.role,
768 segments: message
769 .segments
770 .iter()
771 .map(|segment| match segment {
772 MessageSegment::Text(text) => {
773 SerializedMessageSegment::Text { text: text.clone() }
774 }
775 MessageSegment::Thinking(text) => {
776 SerializedMessageSegment::Thinking { text: text.clone() }
777 }
778 })
779 .collect(),
780 tool_uses: this
781 .tool_uses_for_message(message.id, cx)
782 .into_iter()
783 .map(|tool_use| SerializedToolUse {
784 id: tool_use.id,
785 name: tool_use.name,
786 input: tool_use.input,
787 })
788 .collect(),
789 tool_results: this
790 .tool_results_for_message(message.id)
791 .into_iter()
792 .map(|tool_result| SerializedToolResult {
793 tool_use_id: tool_result.tool_use_id.clone(),
794 is_error: tool_result.is_error,
795 content: tool_result.content.clone(),
796 })
797 .collect(),
798 context: message.context.clone(),
799 })
800 .collect(),
801 initial_project_snapshot,
802 cumulative_token_usage: this.cumulative_token_usage.clone(),
803 detailed_summary_state: this.detailed_summary_state.clone(),
804 })
805 })
806 }
807
808 pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
809 self.system_prompt_context = Some(context);
810 }
811
812 pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
813 &self.system_prompt_context
814 }
815
816 pub fn load_system_prompt_context(
817 &self,
818 cx: &App,
819 ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
820 let project = self.project.read(cx);
821 let tasks = project
822 .visible_worktrees(cx)
823 .map(|worktree| {
824 Self::load_worktree_info_for_system_prompt(
825 project.fs().clone(),
826 worktree.read(cx),
827 cx,
828 )
829 })
830 .collect::<Vec<_>>();
831
832 cx.spawn(async |_cx| {
833 let results = futures::future::join_all(tasks).await;
834 let mut first_err = None;
835 let worktrees = results
836 .into_iter()
837 .map(|(worktree, err)| {
838 if first_err.is_none() && err.is_some() {
839 first_err = err;
840 }
841 worktree
842 })
843 .collect::<Vec<_>>();
844 (AssistantSystemPromptContext::new(worktrees), first_err)
845 })
846 }
847
848 fn load_worktree_info_for_system_prompt(
849 fs: Arc<dyn Fs>,
850 worktree: &Worktree,
851 cx: &App,
852 ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
853 let root_name = worktree.root_name().into();
854 let abs_path = worktree.abs_path();
855
856 let rules_task = load_worktree_rules_file(fs, worktree, cx);
857 let Some(rules_task) = rules_task else {
858 return Task::ready((
859 WorktreeInfoForSystemPrompt {
860 root_name,
861 abs_path,
862 rules_file: None,
863 },
864 None,
865 ));
866 };
867
868 cx.spawn(async move |_| {
869 let (rules_file, rules_file_error) = match rules_task.await {
870 Ok(rules_file) => (Some(rules_file), None),
871 Err(err) => (
872 None,
873 Some(ThreadError::Message {
874 header: "Error loading rules file".into(),
875 message: format!("{err}").into(),
876 }),
877 ),
878 };
879 let worktree_info = WorktreeInfoForSystemPrompt {
880 root_name,
881 abs_path,
882 rules_file,
883 };
884 (worktree_info, rules_file_error)
885 })
886 }
887
888 pub fn send_to_model(
889 &mut self,
890 model: Arc<dyn LanguageModel>,
891 request_kind: RequestKind,
892 cx: &mut Context<Self>,
893 ) {
894 let mut request = self.to_completion_request(request_kind, cx);
895 if model.supports_tools() {
896 request.tools = {
897 let mut tools = Vec::new();
898 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
899 LanguageModelRequestTool {
900 name: tool.name(),
901 description: tool.description(),
902 input_schema: tool.input_schema(model.tool_input_format()),
903 }
904 }));
905
906 tools
907 };
908 }
909
910 self.stream_completion(request, model, cx);
911 }
912
913 pub fn used_tools_since_last_user_message(&self) -> bool {
914 for message in self.messages.iter().rev() {
915 if self.tool_use.message_has_tool_results(message.id) {
916 return true;
917 } else if message.role == Role::User {
918 return false;
919 }
920 }
921
922 false
923 }
924
925 pub fn to_completion_request(
926 &self,
927 request_kind: RequestKind,
928 cx: &App,
929 ) -> LanguageModelRequest {
930 let mut request = LanguageModelRequest {
931 messages: vec![],
932 tools: Vec::new(),
933 stop: Vec::new(),
934 temperature: None,
935 };
936
937 if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
938 if let Some(system_prompt) = self
939 .prompt_builder
940 .generate_assistant_system_prompt(system_prompt_context)
941 .context("failed to generate assistant system prompt")
942 .log_err()
943 {
944 request.messages.push(LanguageModelRequestMessage {
945 role: Role::System,
946 content: vec![MessageContent::Text(system_prompt)],
947 cache: true,
948 });
949 }
950 } else {
951 log::error!("system_prompt_context not set.")
952 }
953
954 for message in &self.messages {
955 let mut request_message = LanguageModelRequestMessage {
956 role: message.role,
957 content: Vec::new(),
958 cache: false,
959 };
960
961 match request_kind {
962 RequestKind::Chat => {
963 self.tool_use
964 .attach_tool_results(message.id, &mut request_message);
965 }
966 RequestKind::Summarize => {
967 // We don't care about tool use during summarization.
968 if self.tool_use.message_has_tool_results(message.id) {
969 continue;
970 }
971 }
972 }
973
974 if !message.segments.is_empty() {
975 request_message
976 .content
977 .push(MessageContent::Text(message.to_string()));
978 }
979
980 match request_kind {
981 RequestKind::Chat => {
982 self.tool_use
983 .attach_tool_uses(message.id, &mut request_message);
984 }
985 RequestKind::Summarize => {
986 // We don't care about tool use during summarization.
987 }
988 };
989
990 request.messages.push(request_message);
991 }
992
993 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
994 if let Some(last) = request.messages.last_mut() {
995 last.cache = true;
996 }
997
998 self.attached_tracked_files_state(&mut request.messages, cx);
999
1000 request
1001 }
1002
1003 fn attached_tracked_files_state(
1004 &self,
1005 messages: &mut Vec<LanguageModelRequestMessage>,
1006 cx: &App,
1007 ) {
1008 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1009
1010 let mut stale_message = String::new();
1011
1012 let action_log = self.action_log.read(cx);
1013
1014 for stale_file in action_log.stale_buffers(cx) {
1015 let Some(file) = stale_file.read(cx).file() else {
1016 continue;
1017 };
1018
1019 if stale_message.is_empty() {
1020 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1021 }
1022
1023 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1024 }
1025
1026 let mut content = Vec::with_capacity(2);
1027
1028 if !stale_message.is_empty() {
1029 content.push(stale_message.into());
1030 }
1031
1032 if action_log.has_edited_files_since_project_diagnostics_check() {
1033 content.push(
1034 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1035 and fix all errors AND warnings you introduced! \
1036 DO NOT mention you're going to do this until you're done."
1037 .into(),
1038 );
1039 }
1040
1041 if !content.is_empty() {
1042 let context_message = LanguageModelRequestMessage {
1043 role: Role::User,
1044 content,
1045 cache: false,
1046 };
1047
1048 messages.push(context_message);
1049 }
1050 }
1051
1052 pub fn stream_completion(
1053 &mut self,
1054 request: LanguageModelRequest,
1055 model: Arc<dyn LanguageModel>,
1056 cx: &mut Context<Self>,
1057 ) {
1058 let pending_completion_id = post_inc(&mut self.completion_count);
1059
1060 let task = cx.spawn(async move |thread, cx| {
1061 let stream = model.stream_completion(request, &cx);
1062 let initial_token_usage =
1063 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1064 let stream_completion = async {
1065 let mut events = stream.await?;
1066 let mut stop_reason = StopReason::EndTurn;
1067 let mut current_token_usage = TokenUsage::default();
1068
1069 while let Some(event) = events.next().await {
1070 let event = event?;
1071
1072 thread.update(cx, |thread, cx| {
1073 match event {
1074 LanguageModelCompletionEvent::StartMessage { .. } => {
1075 thread.insert_message(
1076 Role::Assistant,
1077 vec![MessageSegment::Text(String::new())],
1078 cx,
1079 );
1080 }
1081 LanguageModelCompletionEvent::Stop(reason) => {
1082 stop_reason = reason;
1083 }
1084 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1085 thread.cumulative_token_usage =
1086 thread.cumulative_token_usage.clone() + token_usage.clone()
1087 - current_token_usage.clone();
1088 current_token_usage = token_usage;
1089 }
1090 LanguageModelCompletionEvent::Text(chunk) => {
1091 if let Some(last_message) = thread.messages.last_mut() {
1092 if last_message.role == Role::Assistant {
1093 last_message.push_text(&chunk);
1094 cx.emit(ThreadEvent::StreamedAssistantText(
1095 last_message.id,
1096 chunk,
1097 ));
1098 } else {
1099 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1100 // of a new Assistant response.
1101 //
1102 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1103 // will result in duplicating the text of the chunk in the rendered Markdown.
1104 thread.insert_message(
1105 Role::Assistant,
1106 vec![MessageSegment::Text(chunk.to_string())],
1107 cx,
1108 );
1109 };
1110 }
1111 }
1112 LanguageModelCompletionEvent::Thinking(chunk) => {
1113 if let Some(last_message) = thread.messages.last_mut() {
1114 if last_message.role == Role::Assistant {
1115 last_message.push_thinking(&chunk);
1116 cx.emit(ThreadEvent::StreamedAssistantThinking(
1117 last_message.id,
1118 chunk,
1119 ));
1120 } else {
1121 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1122 // of a new Assistant response.
1123 //
1124 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1125 // will result in duplicating the text of the chunk in the rendered Markdown.
1126 thread.insert_message(
1127 Role::Assistant,
1128 vec![MessageSegment::Thinking(chunk.to_string())],
1129 cx,
1130 );
1131 };
1132 }
1133 }
1134 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1135 let last_assistant_message_id = thread
1136 .messages
1137 .iter_mut()
1138 .rfind(|message| message.role == Role::Assistant)
1139 .map(|message| message.id)
1140 .unwrap_or_else(|| {
1141 thread.insert_message(Role::Assistant, vec![], cx)
1142 });
1143
1144 thread.tool_use.request_tool_use(
1145 last_assistant_message_id,
1146 tool_use,
1147 cx,
1148 );
1149 }
1150 }
1151
1152 thread.touch_updated_at();
1153 cx.emit(ThreadEvent::StreamedCompletion);
1154 cx.notify();
1155 })?;
1156
1157 smol::future::yield_now().await;
1158 }
1159
1160 thread.update(cx, |thread, cx| {
1161 thread
1162 .pending_completions
1163 .retain(|completion| completion.id != pending_completion_id);
1164
1165 if thread.summary.is_none() && thread.messages.len() >= 2 {
1166 thread.summarize(cx);
1167 }
1168 })?;
1169
1170 anyhow::Ok(stop_reason)
1171 };
1172
1173 let result = stream_completion.await;
1174
1175 thread
1176 .update(cx, |thread, cx| {
1177 thread.finalize_pending_checkpoint(cx);
1178 match result.as_ref() {
1179 Ok(stop_reason) => match stop_reason {
1180 StopReason::ToolUse => {
1181 cx.emit(ThreadEvent::UsePendingTools);
1182 }
1183 StopReason::EndTurn => {}
1184 StopReason::MaxTokens => {}
1185 },
1186 Err(error) => {
1187 if error.is::<PaymentRequiredError>() {
1188 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1189 } else if error.is::<MaxMonthlySpendReachedError>() {
1190 cx.emit(ThreadEvent::ShowError(
1191 ThreadError::MaxMonthlySpendReached,
1192 ));
1193 } else {
1194 let error_message = error
1195 .chain()
1196 .map(|err| err.to_string())
1197 .collect::<Vec<_>>()
1198 .join("\n");
1199 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1200 header: "Error interacting with language model".into(),
1201 message: SharedString::from(error_message.clone()),
1202 }));
1203 }
1204
1205 thread.cancel_last_completion(cx);
1206 }
1207 }
1208 cx.emit(ThreadEvent::DoneStreaming);
1209
1210 if let Ok(initial_usage) = initial_token_usage {
1211 let usage = thread.cumulative_token_usage.clone() - initial_usage;
1212
1213 telemetry::event!(
1214 "Assistant Thread Completion",
1215 thread_id = thread.id().to_string(),
1216 model = model.telemetry_id(),
1217 model_provider = model.provider_id().to_string(),
1218 input_tokens = usage.input_tokens,
1219 output_tokens = usage.output_tokens,
1220 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1221 cache_read_input_tokens = usage.cache_read_input_tokens,
1222 );
1223 }
1224 })
1225 .ok();
1226 });
1227
1228 self.pending_completions.push(PendingCompletion {
1229 id: pending_completion_id,
1230 _task: task,
1231 });
1232 }
1233
1234 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1235 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1236 return;
1237 };
1238
1239 if !model.provider.is_authenticated(cx) {
1240 return;
1241 }
1242
1243 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1244 request.messages.push(LanguageModelRequestMessage {
1245 role: Role::User,
1246 content: vec![
1247 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1248 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1249 If the conversation is about a specific subject, include it in the title. \
1250 Be descriptive. DO NOT speak in the first person."
1251 .into(),
1252 ],
1253 cache: false,
1254 });
1255
1256 self.pending_summary = cx.spawn(async move |this, cx| {
1257 async move {
1258 let stream = model.model.stream_completion_text(request, &cx);
1259 let mut messages = stream.await?;
1260
1261 let mut new_summary = String::new();
1262 while let Some(message) = messages.stream.next().await {
1263 let text = message?;
1264 let mut lines = text.lines();
1265 new_summary.extend(lines.next());
1266
1267 // Stop if the LLM generated multiple lines.
1268 if lines.next().is_some() {
1269 break;
1270 }
1271 }
1272
1273 this.update(cx, |this, cx| {
1274 if !new_summary.is_empty() {
1275 this.summary = Some(new_summary.into());
1276 }
1277
1278 cx.emit(ThreadEvent::SummaryGenerated);
1279 })?;
1280
1281 anyhow::Ok(())
1282 }
1283 .log_err()
1284 .await
1285 });
1286 }
1287
1288 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1289 let last_message_id = self.messages.last().map(|message| message.id)?;
1290
1291 match &self.detailed_summary_state {
1292 DetailedSummaryState::Generating { message_id, .. }
1293 | DetailedSummaryState::Generated { message_id, .. }
1294 if *message_id == last_message_id =>
1295 {
1296 // Already up-to-date
1297 return None;
1298 }
1299 _ => {}
1300 }
1301
1302 let ConfiguredModel { model, provider } =
1303 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1304
1305 if !provider.is_authenticated(cx) {
1306 return None;
1307 }
1308
1309 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1310
1311 request.messages.push(LanguageModelRequestMessage {
1312 role: Role::User,
1313 content: vec![
1314 "Generate a detailed summary of this conversation. Include:\n\
1315 1. A brief overview of what was discussed\n\
1316 2. Key facts or information discovered\n\
1317 3. Outcomes or conclusions reached\n\
1318 4. Any action items or next steps if any\n\
1319 Format it in Markdown with headings and bullet points."
1320 .into(),
1321 ],
1322 cache: false,
1323 });
1324
1325 let task = cx.spawn(async move |thread, cx| {
1326 let stream = model.stream_completion_text(request, &cx);
1327 let Some(mut messages) = stream.await.log_err() else {
1328 thread
1329 .update(cx, |this, _cx| {
1330 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1331 })
1332 .log_err();
1333
1334 return;
1335 };
1336
1337 let mut new_detailed_summary = String::new();
1338
1339 while let Some(chunk) = messages.stream.next().await {
1340 if let Some(chunk) = chunk.log_err() {
1341 new_detailed_summary.push_str(&chunk);
1342 }
1343 }
1344
1345 thread
1346 .update(cx, |this, _cx| {
1347 this.detailed_summary_state = DetailedSummaryState::Generated {
1348 text: new_detailed_summary.into(),
1349 message_id: last_message_id,
1350 };
1351 })
1352 .log_err();
1353 });
1354
1355 self.detailed_summary_state = DetailedSummaryState::Generating {
1356 message_id: last_message_id,
1357 };
1358
1359 Some(task)
1360 }
1361
1362 pub fn is_generating_detailed_summary(&self) -> bool {
1363 matches!(
1364 self.detailed_summary_state,
1365 DetailedSummaryState::Generating { .. }
1366 )
1367 }
1368
1369 pub fn use_pending_tools(
1370 &mut self,
1371 cx: &mut Context<Self>,
1372 ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1373 let request = self.to_completion_request(RequestKind::Chat, cx);
1374 let messages = Arc::new(request.messages);
1375 let pending_tool_uses = self
1376 .tool_use
1377 .pending_tool_uses()
1378 .into_iter()
1379 .filter(|tool_use| tool_use.status.is_idle())
1380 .cloned()
1381 .collect::<Vec<_>>();
1382
1383 for tool_use in pending_tool_uses.iter() {
1384 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1385 if tool.needs_confirmation(&tool_use.input, cx)
1386 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1387 {
1388 self.tool_use.confirm_tool_use(
1389 tool_use.id.clone(),
1390 tool_use.ui_text.clone(),
1391 tool_use.input.clone(),
1392 messages.clone(),
1393 tool,
1394 );
1395 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1396 } else {
1397 self.run_tool(
1398 tool_use.id.clone(),
1399 tool_use.ui_text.clone(),
1400 tool_use.input.clone(),
1401 &messages,
1402 tool,
1403 cx,
1404 );
1405 }
1406 }
1407 }
1408
1409 pending_tool_uses
1410 }
1411
1412 pub fn run_tool(
1413 &mut self,
1414 tool_use_id: LanguageModelToolUseId,
1415 ui_text: impl Into<SharedString>,
1416 input: serde_json::Value,
1417 messages: &[LanguageModelRequestMessage],
1418 tool: Arc<dyn Tool>,
1419 cx: &mut Context<Thread>,
1420 ) {
1421 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1422 self.tool_use
1423 .run_pending_tool(tool_use_id, ui_text.into(), task);
1424 }
1425
1426 fn spawn_tool_use(
1427 &mut self,
1428 tool_use_id: LanguageModelToolUseId,
1429 messages: &[LanguageModelRequestMessage],
1430 input: serde_json::Value,
1431 tool: Arc<dyn Tool>,
1432 cx: &mut Context<Thread>,
1433 ) -> Task<()> {
1434 let tool_name: Arc<str> = tool.name().into();
1435
1436 let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1437 Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1438 } else {
1439 tool.run(
1440 input,
1441 messages,
1442 self.project.clone(),
1443 self.action_log.clone(),
1444 cx,
1445 )
1446 };
1447
1448 cx.spawn({
1449 async move |thread: WeakEntity<Thread>, cx| {
1450 let output = run_tool.await;
1451
1452 thread
1453 .update(cx, |thread, cx| {
1454 let pending_tool_use = thread.tool_use.insert_tool_output(
1455 tool_use_id.clone(),
1456 tool_name,
1457 output,
1458 cx,
1459 );
1460
1461 cx.emit(ThreadEvent::ToolFinished {
1462 tool_use_id,
1463 pending_tool_use,
1464 canceled: false,
1465 });
1466 })
1467 .ok();
1468 }
1469 })
1470 }
1471
1472 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1473 // Insert a user message to contain the tool results.
1474 self.insert_user_message(
1475 // TODO: Sending up a user message without any content results in the model sending back
1476 // responses that also don't have any content. We currently don't handle this case well,
1477 // so for now we provide some text to keep the model on track.
1478 "Here are the tool results.",
1479 Vec::new(),
1480 None,
1481 cx,
1482 );
1483 }
1484
1485 /// Cancels the last pending completion, if there are any pending.
1486 ///
1487 /// Returns whether a completion was canceled.
1488 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1489 let canceled = if self.pending_completions.pop().is_some() {
1490 true
1491 } else {
1492 let mut canceled = false;
1493 for pending_tool_use in self.tool_use.cancel_pending() {
1494 canceled = true;
1495 cx.emit(ThreadEvent::ToolFinished {
1496 tool_use_id: pending_tool_use.id.clone(),
1497 pending_tool_use: Some(pending_tool_use),
1498 canceled: true,
1499 });
1500 }
1501 canceled
1502 };
1503 self.finalize_pending_checkpoint(cx);
1504 canceled
1505 }
1506
1507 /// Returns the feedback given to the thread, if any.
1508 pub fn feedback(&self) -> Option<ThreadFeedback> {
1509 self.feedback
1510 }
1511
1512 /// Reports feedback about the thread and stores it in our telemetry backend.
1513 pub fn report_feedback(
1514 &mut self,
1515 feedback: ThreadFeedback,
1516 cx: &mut Context<Self>,
1517 ) -> Task<Result<()>> {
1518 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1519 let serialized_thread = self.serialize(cx);
1520 let thread_id = self.id().clone();
1521 let client = self.project.read(cx).client();
1522 self.feedback = Some(feedback);
1523 cx.notify();
1524
1525 cx.background_spawn(async move {
1526 let final_project_snapshot = final_project_snapshot.await;
1527 let serialized_thread = serialized_thread.await?;
1528 let thread_data =
1529 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1530
1531 let rating = match feedback {
1532 ThreadFeedback::Positive => "positive",
1533 ThreadFeedback::Negative => "negative",
1534 };
1535 telemetry::event!(
1536 "Assistant Thread Rated",
1537 rating,
1538 thread_id,
1539 thread_data,
1540 final_project_snapshot
1541 );
1542 client.telemetry().flush_events();
1543
1544 Ok(())
1545 })
1546 }
1547
1548 /// Create a snapshot of the current project state including git information and unsaved buffers.
1549 fn project_snapshot(
1550 project: Entity<Project>,
1551 cx: &mut Context<Self>,
1552 ) -> Task<Arc<ProjectSnapshot>> {
1553 let git_store = project.read(cx).git_store().clone();
1554 let worktree_snapshots: Vec<_> = project
1555 .read(cx)
1556 .visible_worktrees(cx)
1557 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1558 .collect();
1559
1560 cx.spawn(async move |_, cx| {
1561 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1562
1563 let mut unsaved_buffers = Vec::new();
1564 cx.update(|app_cx| {
1565 let buffer_store = project.read(app_cx).buffer_store();
1566 for buffer_handle in buffer_store.read(app_cx).buffers() {
1567 let buffer = buffer_handle.read(app_cx);
1568 if buffer.is_dirty() {
1569 if let Some(file) = buffer.file() {
1570 let path = file.path().to_string_lossy().to_string();
1571 unsaved_buffers.push(path);
1572 }
1573 }
1574 }
1575 })
1576 .ok();
1577
1578 Arc::new(ProjectSnapshot {
1579 worktree_snapshots,
1580 unsaved_buffer_paths: unsaved_buffers,
1581 timestamp: Utc::now(),
1582 })
1583 })
1584 }
1585
1586 fn worktree_snapshot(
1587 worktree: Entity<project::Worktree>,
1588 git_store: Entity<GitStore>,
1589 cx: &App,
1590 ) -> Task<WorktreeSnapshot> {
1591 cx.spawn(async move |cx| {
1592 // Get worktree path and snapshot
1593 let worktree_info = cx.update(|app_cx| {
1594 let worktree = worktree.read(app_cx);
1595 let path = worktree.abs_path().to_string_lossy().to_string();
1596 let snapshot = worktree.snapshot();
1597 (path, snapshot)
1598 });
1599
1600 let Ok((worktree_path, _snapshot)) = worktree_info else {
1601 return WorktreeSnapshot {
1602 worktree_path: String::new(),
1603 git_state: None,
1604 };
1605 };
1606
1607 let git_state = git_store
1608 .update(cx, |git_store, cx| {
1609 git_store
1610 .repositories()
1611 .values()
1612 .find(|repo| {
1613 repo.read(cx)
1614 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1615 .is_some()
1616 })
1617 .cloned()
1618 })
1619 .ok()
1620 .flatten()
1621 .map(|repo| {
1622 repo.update(cx, |repo, _| {
1623 let current_branch =
1624 repo.branch.as_ref().map(|branch| branch.name.to_string());
1625 repo.send_job(None, |state, _| async move {
1626 let RepositoryState::Local { backend, .. } = state else {
1627 return GitState {
1628 remote_url: None,
1629 head_sha: None,
1630 current_branch,
1631 diff: None,
1632 };
1633 };
1634
1635 let remote_url = backend.remote_url("origin");
1636 let head_sha = backend.head_sha();
1637 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1638
1639 GitState {
1640 remote_url,
1641 head_sha,
1642 current_branch,
1643 diff,
1644 }
1645 })
1646 })
1647 });
1648
1649 let git_state = match git_state {
1650 Some(git_state) => match git_state.ok() {
1651 Some(git_state) => git_state.await.ok(),
1652 None => None,
1653 },
1654 None => None,
1655 };
1656
1657 WorktreeSnapshot {
1658 worktree_path,
1659 git_state,
1660 }
1661 })
1662 }
1663
1664 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1665 let mut markdown = Vec::new();
1666
1667 if let Some(summary) = self.summary() {
1668 writeln!(markdown, "# {summary}\n")?;
1669 };
1670
1671 for message in self.messages() {
1672 writeln!(
1673 markdown,
1674 "## {role}\n",
1675 role = match message.role {
1676 Role::User => "User",
1677 Role::Assistant => "Assistant",
1678 Role::System => "System",
1679 }
1680 )?;
1681
1682 if !message.context.is_empty() {
1683 writeln!(markdown, "{}", message.context)?;
1684 }
1685
1686 for segment in &message.segments {
1687 match segment {
1688 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1689 MessageSegment::Thinking(text) => {
1690 writeln!(markdown, "<think>{}</think>\n", text)?
1691 }
1692 }
1693 }
1694
1695 for tool_use in self.tool_uses_for_message(message.id, cx) {
1696 writeln!(
1697 markdown,
1698 "**Use Tool: {} ({})**",
1699 tool_use.name, tool_use.id
1700 )?;
1701 writeln!(markdown, "```json")?;
1702 writeln!(
1703 markdown,
1704 "{}",
1705 serde_json::to_string_pretty(&tool_use.input)?
1706 )?;
1707 writeln!(markdown, "```")?;
1708 }
1709
1710 for tool_result in self.tool_results_for_message(message.id) {
1711 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1712 if tool_result.is_error {
1713 write!(markdown, " (Error)")?;
1714 }
1715
1716 writeln!(markdown, "**\n")?;
1717 writeln!(markdown, "{}", tool_result.content)?;
1718 }
1719 }
1720
1721 Ok(String::from_utf8_lossy(&markdown).to_string())
1722 }
1723
1724 pub fn keep_edits_in_range(
1725 &mut self,
1726 buffer: Entity<language::Buffer>,
1727 buffer_range: Range<language::Anchor>,
1728 cx: &mut Context<Self>,
1729 ) {
1730 self.action_log.update(cx, |action_log, cx| {
1731 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1732 });
1733 }
1734
1735 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1736 self.action_log
1737 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1738 }
1739
1740 pub fn reject_edits_in_range(
1741 &mut self,
1742 buffer: Entity<language::Buffer>,
1743 buffer_range: Range<language::Anchor>,
1744 cx: &mut Context<Self>,
1745 ) -> Task<Result<()>> {
1746 self.action_log.update(cx, |action_log, cx| {
1747 action_log.reject_edits_in_range(buffer, buffer_range, cx)
1748 })
1749 }
1750
1751 pub fn action_log(&self) -> &Entity<ActionLog> {
1752 &self.action_log
1753 }
1754
1755 pub fn project(&self) -> &Entity<Project> {
1756 &self.project
1757 }
1758
1759 pub fn cumulative_token_usage(&self) -> TokenUsage {
1760 self.cumulative_token_usage.clone()
1761 }
1762
1763 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1764 let model_registry = LanguageModelRegistry::read_global(cx);
1765 let Some(model) = model_registry.default_model() else {
1766 return TotalTokenUsage::default();
1767 };
1768
1769 let max = model.model.max_token_count();
1770
1771 #[cfg(debug_assertions)]
1772 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1773 .unwrap_or("0.8".to_string())
1774 .parse()
1775 .unwrap();
1776 #[cfg(not(debug_assertions))]
1777 let warning_threshold: f32 = 0.8;
1778
1779 let total = self.cumulative_token_usage.total_tokens() as usize;
1780
1781 let ratio = if total >= max {
1782 TokenUsageRatio::Exceeded
1783 } else if total as f32 / max as f32 >= warning_threshold {
1784 TokenUsageRatio::Warning
1785 } else {
1786 TokenUsageRatio::Normal
1787 };
1788
1789 TotalTokenUsage { total, max, ratio }
1790 }
1791
1792 pub fn deny_tool_use(
1793 &mut self,
1794 tool_use_id: LanguageModelToolUseId,
1795 tool_name: Arc<str>,
1796 cx: &mut Context<Self>,
1797 ) {
1798 let err = Err(anyhow::anyhow!(
1799 "Permission to run tool action denied by user"
1800 ));
1801
1802 self.tool_use
1803 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1804
1805 cx.emit(ThreadEvent::ToolFinished {
1806 tool_use_id,
1807 pending_tool_use: None,
1808 canceled: true,
1809 });
1810 }
1811}
1812
/// Errors a thread reports to its observers through `ThreadEvent::ShowError`.
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// The completion failed with a `PaymentRequiredError`.
    PaymentRequired,
    /// The completion failed with a `MaxMonthlySpendReachedError`.
    MaxMonthlySpendReached,
    /// Any other failure, described by a header and a message.
    Message {
        /// Short title for the error.
        header: SharedString,
        /// Error details; for completion failures this is the error chain joined by newlines.
        message: SharedString,
    },
}
1822
/// Events emitted by a `Thread`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// Emitted after each completion event from the model has been applied to the thread.
    StreamedCompletion,
    /// A chunk of text was appended to the assistant message with the given id.
    StreamedAssistantText(MessageId, String),
    /// A chunk of thinking text was appended to the assistant message with the given id.
    StreamedAssistantThinking(MessageId, String),
    /// The completion task finished, whether it succeeded or failed.
    DoneStreaming,
    // NOTE(review): the next three variants are emitted outside this section of the
    // file; presumably they signal a message with the given id being added, edited
    // or deleted — confirm at the emission sites.
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// The title-generation task finished; emitted even when the summary was left unchanged.
    SummaryGenerated,
    // NOTE(review): emitted outside this section of the file.
    SummaryChanged,
    /// The model stopped to use tools; pending tool uses should now be started.
    UsePendingTools,
    /// A tool use finished, was canceled, or was denied by the user.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    // NOTE(review): emitted outside this section of the file.
    CheckpointChanged,
    /// A pending tool use requires confirmation before it is run.
    ToolConfirmationNeeded,
}
1847
// Marks `Thread` as an emitter of `ThreadEvent`, which is what `cx.emit(ThreadEvent::…)` relies on.
impl EventEmitter<ThreadEvent> for Thread {}
1849
/// A completion request that is still in flight.
///
/// Entries are removed by id when their stream ends, and popped by
/// `cancel_last_completion` to cancel them.
struct PendingCompletion {
    /// Identifier used to remove the entry once the completion has finished.
    id: usize,
    /// The task driving the completion. It is only held, never read; presumably
    /// dropping it stops the task — confirm against gpui's `Task` semantics.
    _task: Task<()>,
}
1854
1855#[cfg(test)]
1856mod tests {
1857 use super::*;
1858 use crate::{ThreadStore, context_store::ContextStore, thread_store};
1859 use assistant_settings::AssistantSettings;
1860 use context_server::ContextServerSettings;
1861 use editor::EditorSettings;
1862 use gpui::TestAppContext;
1863 use project::{FakeFs, Project};
1864 use prompt_store::PromptBuilder;
1865 use serde_json::json;
1866 use settings::{Settings, SettingsStore};
1867 use std::sync::Arc;
1868 use theme::ThemeSettings;
1869 use util::path;
1870 use workspace::Workspace;
1871
    /// A user message inserted with a file context stores the formatted context on
    /// the message and prepends it to the message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // The file added above is the only entry in the context store.
        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // Exact text expected in `message.context`; `{{`/`}}` are escaped braces for `format!`.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request carries the context followed directly by the user's text.
        assert_eq!(request.messages.len(), 1);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[0].string_contents(), expected_full_message);
    }
1939
1940 #[gpui::test]
1941 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
1942 init_test_settings(cx);
1943
1944 let project = create_test_project(
1945 cx,
1946 json!({
1947 "file1.rs": "fn function1() {}\n",
1948 "file2.rs": "fn function2() {}\n",
1949 "file3.rs": "fn function3() {}\n",
1950 }),
1951 )
1952 .await;
1953
1954 let (_, _thread_store, thread, context_store) =
1955 setup_test_environment(cx, project.clone()).await;
1956
1957 // Open files individually
1958 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
1959 .await
1960 .unwrap();
1961 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
1962 .await
1963 .unwrap();
1964 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
1965 .await
1966 .unwrap();
1967
1968 // Get the context objects
1969 let contexts = context_store.update(cx, |store, _| store.context().clone());
1970 assert_eq!(contexts.len(), 3);
1971
1972 // First message with context 1
1973 let message1_id = thread.update(cx, |thread, cx| {
1974 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
1975 });
1976
1977 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
1978 let message2_id = thread.update(cx, |thread, cx| {
1979 thread.insert_user_message(
1980 "Message 2",
1981 vec![contexts[0].clone(), contexts[1].clone()],
1982 None,
1983 cx,
1984 )
1985 });
1986
1987 // Third message with all three contexts (contexts 1 and 2 should be skipped)
1988 let message3_id = thread.update(cx, |thread, cx| {
1989 thread.insert_user_message(
1990 "Message 3",
1991 vec![
1992 contexts[0].clone(),
1993 contexts[1].clone(),
1994 contexts[2].clone(),
1995 ],
1996 None,
1997 cx,
1998 )
1999 });
2000
2001 // Check what contexts are included in each message
2002 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2003 (
2004 thread.message(message1_id).unwrap().clone(),
2005 thread.message(message2_id).unwrap().clone(),
2006 thread.message(message3_id).unwrap().clone(),
2007 )
2008 });
2009
2010 // First message should include context 1
2011 assert!(message1.context.contains("file1.rs"));
2012
2013 // Second message should include only context 2 (not 1)
2014 assert!(!message2.context.contains("file1.rs"));
2015 assert!(message2.context.contains("file2.rs"));
2016
2017 // Third message should include only context 3 (not 1 or 2)
2018 assert!(!message3.context.contains("file1.rs"));
2019 assert!(!message3.context.contains("file2.rs"));
2020 assert!(message3.context.contains("file3.rs"));
2021
2022 // Check entire request to make sure all contexts are properly included
2023 let request = thread.read_with(cx, |thread, cx| {
2024 thread.to_completion_request(RequestKind::Chat, cx)
2025 });
2026
2027 // The request should contain all 3 messages
2028 assert_eq!(request.messages.len(), 3);
2029
2030 // Check that the contexts are properly formatted in each message
2031 assert!(request.messages[0].string_contents().contains("file1.rs"));
2032 assert!(!request.messages[0].string_contents().contains("file2.rs"));
2033 assert!(!request.messages[0].string_contents().contains("file3.rs"));
2034
2035 assert!(!request.messages[1].string_contents().contains("file1.rs"));
2036 assert!(request.messages[1].string_contents().contains("file2.rs"));
2037 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2038
2039 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2040 assert!(!request.messages[2].string_contents().contains("file2.rs"));
2041 assert!(request.messages[2].string_contents().contains("file3.rs"));
2042 }
2043
2044 #[gpui::test]
2045 async fn test_message_without_files(cx: &mut TestAppContext) {
2046 init_test_settings(cx);
2047
2048 let project = create_test_project(
2049 cx,
2050 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2051 )
2052 .await;
2053
2054 let (_, _thread_store, thread, _context_store) =
2055 setup_test_environment(cx, project.clone()).await;
2056
2057 // Insert user message without any context (empty context vector)
2058 let message_id = thread.update(cx, |thread, cx| {
2059 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2060 });
2061
2062 // Check content and context in message object
2063 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2064
2065 // Context should be empty when no files are included
2066 assert_eq!(message.role, Role::User);
2067 assert_eq!(message.segments.len(), 1);
2068 assert_eq!(
2069 message.segments[0],
2070 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2071 );
2072 assert_eq!(message.context, "");
2073
2074 // Check message in request
2075 let request = thread.read_with(cx, |thread, cx| {
2076 thread.to_completion_request(RequestKind::Chat, cx)
2077 });
2078
2079 assert_eq!(request.messages.len(), 1);
2080 assert_eq!(
2081 request.messages[0].string_contents(),
2082 "What is the best way to learn Rust?"
2083 );
2084
2085 // Add second message, also without context
2086 let message2_id = thread.update(cx, |thread, cx| {
2087 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2088 });
2089
2090 let message2 =
2091 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2092 assert_eq!(message2.context, "");
2093
2094 // Check that both messages appear in the request
2095 let request = thread.read_with(cx, |thread, cx| {
2096 thread.to_completion_request(RequestKind::Chat, cx)
2097 });
2098
2099 assert_eq!(request.messages.len(), 2);
2100 assert_eq!(
2101 request.messages[0].string_contents(),
2102 "What is the best way to learn Rust?"
2103 );
2104 assert_eq!(
2105 request.messages[1].string_contents(),
2106 "Are there any good books?"
2107 );
2108 }
2109
2110 #[gpui::test]
2111 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2112 init_test_settings(cx);
2113
2114 let project = create_test_project(
2115 cx,
2116 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2117 )
2118 .await;
2119
2120 let (_workspace, _thread_store, thread, context_store) =
2121 setup_test_environment(cx, project.clone()).await;
2122
2123 // Open buffer and add it to context
2124 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2125 .await
2126 .unwrap();
2127
2128 let context =
2129 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2130
2131 // Insert user message with the buffer as context
2132 thread.update(cx, |thread, cx| {
2133 thread.insert_user_message("Explain this code", vec![context], None, cx)
2134 });
2135
2136 // Create a request and check that it doesn't have a stale buffer warning yet
2137 let initial_request = thread.read_with(cx, |thread, cx| {
2138 thread.to_completion_request(RequestKind::Chat, cx)
2139 });
2140
2141 // Make sure we don't have a stale file warning yet
2142 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2143 msg.string_contents()
2144 .contains("These files changed since last read:")
2145 });
2146 assert!(
2147 !has_stale_warning,
2148 "Should not have stale buffer warning before buffer is modified"
2149 );
2150
2151 // Modify the buffer
2152 buffer.update(cx, |buffer, cx| {
2153 // Find a position at the end of line 1
2154 buffer.edit(
2155 [(1..1, "\n println!(\"Added a new line\");\n")],
2156 None,
2157 cx,
2158 );
2159 });
2160
2161 // Insert another user message without context
2162 thread.update(cx, |thread, cx| {
2163 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2164 });
2165
2166 // Create a new request and check for the stale buffer warning
2167 let new_request = thread.read_with(cx, |thread, cx| {
2168 thread.to_completion_request(RequestKind::Chat, cx)
2169 });
2170
2171 // We should have a stale file warning as the last message
2172 let last_message = new_request
2173 .messages
2174 .last()
2175 .expect("Request should have messages");
2176
2177 // The last message should be the stale buffer notification
2178 assert_eq!(last_message.role, Role::User);
2179
2180 // Check the exact content of the message
2181 let expected_content = "These files changed since last read:\n- code.rs\n";
2182 assert_eq!(
2183 last_message.string_contents(),
2184 expected_content,
2185 "Last message should be exactly the stale buffer notification"
2186 );
2187 }
2188
2189 fn init_test_settings(cx: &mut TestAppContext) {
2190 cx.update(|cx| {
2191 let settings_store = SettingsStore::test(cx);
2192 cx.set_global(settings_store);
2193 language::init(cx);
2194 Project::init_settings(cx);
2195 AssistantSettings::register(cx);
2196 thread_store::init(cx);
2197 workspace::init_settings(cx);
2198 ThemeSettings::register(cx);
2199 ContextServerSettings::register(cx);
2200 EditorSettings::register(cx);
2201 });
2202 }
2203
2204 // Helper to create a test project with test files
2205 async fn create_test_project(
2206 cx: &mut TestAppContext,
2207 files: serde_json::Value,
2208 ) -> Entity<Project> {
2209 let fs = FakeFs::new(cx.executor());
2210 fs.insert_tree(path!("/test"), files).await;
2211 Project::test(fs, [path!("/test").as_ref()], cx).await
2212 }
2213
2214 async fn setup_test_environment(
2215 cx: &mut TestAppContext,
2216 project: Entity<Project>,
2217 ) -> (
2218 Entity<Workspace>,
2219 Entity<ThreadStore>,
2220 Entity<Thread>,
2221 Entity<ContextStore>,
2222 ) {
2223 let (workspace, cx) =
2224 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2225
2226 let thread_store = cx.update(|_, cx| {
2227 ThreadStore::new(
2228 project.clone(),
2229 Arc::default(),
2230 Arc::new(PromptBuilder::new(None).unwrap()),
2231 cx,
2232 )
2233 .unwrap()
2234 });
2235
2236 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2237 let context_store = cx.new(|_cx| ContextStore::new(workspace.downgrade(), None));
2238
2239 (workspace, thread_store, thread, context_store)
2240 }
2241
2242 async fn add_file_to_context(
2243 project: &Entity<Project>,
2244 context_store: &Entity<ContextStore>,
2245 path: &str,
2246 cx: &mut TestAppContext,
2247 ) -> Result<Entity<language::Buffer>> {
2248 let buffer_path = project
2249 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2250 .unwrap();
2251
2252 let buffer = project
2253 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2254 .await
2255 .unwrap();
2256
2257 context_store
2258 .update(cx, |store, cx| {
2259 store.add_file_from_buffer(buffer.clone(), cx)
2260 })
2261 .await?;
2262
2263 Ok(buffer)
2264 }
2265}