1use std::fmt::Write as _;
2use std::io::Write;
3use std::sync::Arc;
4
5use anyhow::{Context as _, Result};
6use assistant_tool::{ActionLog, ToolWorkingSet};
7use chrono::{DateTime, Utc};
8use collections::{BTreeMap, HashMap, HashSet};
9use fs::Fs;
10use futures::future::Shared;
11use futures::{FutureExt, StreamExt as _};
12use git;
13use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
14use language_model::{
15 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
16 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
17 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
18 Role, StopReason, TokenUsage,
19};
20use project::git_store::{GitStore, GitStoreCheckpoint};
21use project::{Project, Worktree};
22use prompt_store::{
23 AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
24};
25use scripting_tool::{ScriptingSession, ScriptingTool};
26use serde::{Deserialize, Serialize};
27use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
28use uuid::Uuid;
29
30use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
31use crate::thread_store::{
32 SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
33};
34use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
35
/// Why a completion request is being built; tool uses/results are only
/// attached for [`RequestKind::Chat`] requests.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A normal conversational turn.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
42
/// Globally-unique identifier for a [`Thread`], backed by a UUID string.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);
45
46impl ThreadId {
47 pub fn new() -> Self {
48 Self(Uuid::new_v4().to_string().into())
49 }
50}
51
52impl std::fmt::Display for ThreadId {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "{}", self.0)
55 }
56}
57
/// Monotonically-increasing identifier for a message within a single thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
60
61impl MessageId {
62 fn post_inc(&mut self) -> Self {
63 Self(post_inc(&mut self.0))
64 }
65}
66
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Thread-local id, assigned in insertion order.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// The plain-text body of the message.
    pub text: String,
}
74
/// Point-in-time description of the project, captured when a thread is
/// created and persisted alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time —
    /// presumably; the capture site is outside this view, confirm there.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
81
/// Snapshot of a single worktree: its root path plus git metadata, if any.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Root path of the worktree.
    pub worktree_path: String,
    /// Git repository state; `None` when no repository information was found.
    pub git_state: Option<GitState>,
}
87
/// Git repository metadata captured in a [`WorktreeSnapshot`].
///
/// Every field is optional since each piece of information can be
/// independently unavailable (no remote, detached/empty HEAD, etc.).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
95
/// Pairs a git checkpoint with the message it belongs to, so restoring the
/// checkpoint can also rewind the thread to that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
101
/// User feedback (thumbs up / thumbs down) recorded for a thread.
#[derive(Copy, Clone, Debug)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
107
/// State of the most recent checkpoint-restore attempt.
pub enum LastRestoreCheckpoint {
    /// The restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` holds the human-readable cause.
    Error {
        message_id: MessageId,
        error: String,
    },
}
117
118impl LastRestoreCheckpoint {
119 pub fn message_id(&self) -> MessageId {
120 match self {
121 LastRestoreCheckpoint::Pending { message_id } => *message_id,
122 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
123 }
124 }
125}
126
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time the thread's content changed.
    updated_at: DateTime<Utc>,
    /// LLM-generated title; `None` until summarization succeeds.
    summary: Option<SharedString>,
    /// In-flight summarization task, if any.
    pending_summary: Task<Option<()>>,
    messages: Vec<Message>,
    /// Id to assign to the next inserted message.
    next_message_id: MessageId,
    /// Every context snapshot ever attached to this thread, keyed by id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Which context ids were attached to each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Git checkpoints captured alongside messages.
    checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
    /// Counter used to hand out ids to pending completions.
    completion_count: usize,
    /// Completions currently streaming from the model.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Arc<ToolWorkingSet>,
    /// State for regular (non-scripting) tool uses.
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    scripting_session: Entity<ScriptingSession>,
    /// State for scripting-tool uses, tracked separately from `tool_use`.
    scripting_tool_use: ToolUseState,
    /// Project snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Total token usage accumulated across all completions.
    cumulative_token_usage: TokenUsage,
    feedback: Option<ThreadFeedback>,
}
153
154impl Thread {
    /// Creates a new, empty thread backed by the given project.
    ///
    /// Spawns a background task that captures an initial project snapshot
    /// for later persistence.
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
            scripting_tool_use: ToolUseState::new(tools),
            action_log: cx.new(|_| ActionLog::new()),
            // Kick off the snapshot capture immediately; `Shared` lets
            // `serialize` await the same result later without re-capturing.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
            feedback: None,
        }
    }
192
    /// Reconstructs a thread from its persisted form.
    ///
    /// Tool-use state is rebuilt from the serialized messages, split into
    /// scripting vs. regular tool uses by tool name. Context attachments and
    /// checkpoints are not persisted and start out empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
                name != ScriptingTool::NAME
            });
        let scripting_tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
                name == ScriptingTool::NAME
            });
        let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    text: message.text,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            project,
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new()),
            scripting_session,
            scripting_tool_use,
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            // TODO: persist token usage?
            cumulative_token_usage: TokenUsage::default(),
            feedback: None,
        }
    }
253
    /// The thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
257
    /// Returns `true` when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
261
    /// Last time the thread's content changed.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
265
    /// Stamps the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
269
    /// The generated summary/title, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
273
274 pub fn summary_or_default(&self) -> SharedString {
275 const DEFAULT: SharedString = SharedString::new_static("New Thread");
276 self.summary.clone().unwrap_or(DEFAULT)
277 }
278
    /// Overrides the thread's summary and notifies observers.
    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        self.summary = Some(summary.into());
        cx.emit(ThreadEvent::SummaryChanged);
    }
283
284 pub fn message(&self, id: MessageId) -> Option<&Message> {
285 self.messages.iter().find(|message| message.id == id)
286 }
287
    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
291
292 pub fn is_generating(&self) -> bool {
293 !self.pending_completions.is_empty() || !self.all_tools_finished()
294 }
295
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }
299
300 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
301 let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
302 Some(ThreadCheckpoint {
303 message_id: id,
304 git_checkpoint: checkpoint,
305 })
306 }
307
    /// Restores the project's files to the given checkpoint.
    ///
    /// Immediately records the restore as pending and emits
    /// [`ThreadEvent::CheckpointChanged`]. On success the thread is truncated
    /// back to the checkpoint's message; on failure the error is stored in
    /// `last_restore_checkpoint`. A second `CheckpointChanged` is emitted
    /// once the outcome is known.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);

        let project = self.project.read(cx);
        let restore = project
            .git_store()
            .read(cx)
            .restore_checkpoint(checkpoint.git_checkpoint, cx);
        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.last_restore_checkpoint = None;
                    this.truncate(checkpoint.message_id, cx);
                }
                cx.emit(ThreadEvent::CheckpointChanged);
            })?;
            result
        })
    }
340
    /// Outcome of the most recent checkpoint restore, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
344
345 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
346 let Some(message_ix) = self
347 .messages
348 .iter()
349 .rposition(|message| message.id == message_id)
350 else {
351 return;
352 };
353 for deleted_message in self.messages.drain(message_ix..) {
354 self.context_by_message.remove(&deleted_message.id);
355 self.checkpoints_by_message.remove(&deleted_message.id);
356 }
357 cx.notify();
358 }
359
360 pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
361 let context = self.context_by_message.get(&id)?;
362 Some(
363 context
364 .into_iter()
365 .filter_map(|context_id| self.context.get(&context_id))
366 .cloned()
367 .collect::<Vec<_>>(),
368 )
369 }
370
371 /// Returns whether all of the tool uses have finished running.
372 pub fn all_tools_finished(&self) -> bool {
373 let mut all_pending_tool_uses = self
374 .tool_use
375 .pending_tool_uses()
376 .into_iter()
377 .chain(self.scripting_tool_use.pending_tool_uses());
378
379 // If the only pending tool uses left are the ones with errors, then
380 // that means that we've finished running all of the pending tools.
381 all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
382 }
383
    /// Regular (non-scripting) tool uses requested by the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
387
    /// Scripting-tool uses requested by the given message.
    pub fn scripting_tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.scripting_tool_use.tool_uses_for_message(id, cx)
    }
391
    /// Results of regular tool uses attached to the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
395
    /// Looks up the result of a single regular tool use by its id.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
399
    /// Results of scripting-tool uses attached to the given message.
    pub fn scripting_tool_results_for_message(
        &self,
        id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.scripting_tool_use.tool_results_for_message(id)
    }
406
    /// Whether the given message carries any regular tool results.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
410
    /// Whether the given message carries any scripting-tool results.
    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
        self.scripting_tool_use.message_has_tool_results(message_id)
    }
414
415 pub fn insert_user_message(
416 &mut self,
417 text: impl Into<String>,
418 context: Vec<ContextSnapshot>,
419 checkpoint: Option<GitStoreCheckpoint>,
420 cx: &mut Context<Self>,
421 ) -> MessageId {
422 let message_id = self.insert_message(Role::User, text, cx);
423 let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
424 self.context
425 .extend(context.into_iter().map(|context| (context.id, context)));
426 self.context_by_message.insert(message_id, context_ids);
427 if let Some(checkpoint) = checkpoint {
428 self.checkpoints_by_message.insert(message_id, checkpoint);
429 }
430 message_id
431 }
432
433 pub fn insert_message(
434 &mut self,
435 role: Role,
436 text: impl Into<String>,
437 cx: &mut Context<Self>,
438 ) -> MessageId {
439 let id = self.next_message_id.post_inc();
440 self.messages.push(Message {
441 id,
442 role,
443 text: text.into(),
444 });
445 self.touch_updated_at();
446 cx.emit(ThreadEvent::MessageAdded(id));
447 id
448 }
449
450 pub fn edit_message(
451 &mut self,
452 id: MessageId,
453 new_role: Role,
454 new_text: String,
455 cx: &mut Context<Self>,
456 ) -> bool {
457 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
458 return false;
459 };
460 message.role = new_role;
461 message.text = new_text;
462 self.touch_updated_at();
463 cx.emit(ThreadEvent::MessageEdited(id));
464 true
465 }
466
467 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
468 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
469 return false;
470 };
471 self.messages.remove(index);
472 self.context_by_message.remove(&id);
473 self.touch_updated_at();
474 cx.emit(ThreadEvent::MessageDeleted(id));
475 true
476 }
477
478 /// Returns the representation of this [`Thread`] in a textual form.
479 ///
480 /// This is the representation we use when attaching a thread as context to another thread.
481 pub fn text(&self) -> String {
482 let mut text = String::new();
483
484 for message in &self.messages {
485 text.push_str(match message.role {
486 language_model::Role::User => "User:",
487 language_model::Role::Assistant => "Assistant:",
488 language_model::Role::System => "System:",
489 });
490 text.push('\n');
491
492 text.push_str(&message.text);
493 text.push('\n');
494 }
495
496 text
497 }
498
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot (captured at thread creation)
    /// before reading the thread state; regular and scripting tool
    /// uses/results are merged into each message's serialized form.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        text: message.text.clone(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .chain(this.scripting_tool_uses_for_message(message.id, cx))
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .chain(this.scripting_tool_results_for_message(message.id))
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
            })
        })
    }
539
    /// Stores the worktree/rules context used to build the system prompt.
    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
        self.system_prompt_context = Some(context);
    }
543
    /// The stored system-prompt context, if it has been loaded.
    // NOTE(review): returning `&Option<T>` rather than `Option<&T>` is
    // unidiomatic, but changing it would break callers — leaving as-is.
    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
        &self.system_prompt_context
    }
547
    /// Loads system-prompt info (root name, path, rules file) for every
    /// visible worktree concurrently.
    ///
    /// Only the first error encountered is surfaced; worktrees whose rules
    /// file failed to load are still included (without the rules).
    pub fn load_system_prompt_context(
        &self,
        cx: &App,
    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
        let project = self.project.read(cx);
        let tasks = project
            .visible_worktrees(cx)
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(
                    project.fs().clone(),
                    worktree.read(cx),
                    cx,
                )
            })
            .collect::<Vec<_>>();

        cx.spawn(async |_cx| {
            let results = futures::future::join_all(tasks).await;
            let mut first_err = None;
            let worktrees = results
                .into_iter()
                .map(|(worktree, err)| {
                    // Keep only the first error; later ones are dropped.
                    if first_err.is_none() && err.is_some() {
                        first_err = err;
                    }
                    worktree
                })
                .collect::<Vec<_>>();
            (AssistantSystemPromptContext::new(worktrees), first_err)
        })
    }
579
    /// Gathers a worktree's root name/path and, if present, the contents of
    /// its assistant rules file.
    ///
    /// Returns the worktree info plus an optional error describing a rules
    /// file that existed but could not be read.
    fn load_worktree_info_for_system_prompt(
        fs: Arc<dyn Fs>,
        worktree: &Worktree,
        cx: &App,
    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
        let root_name = worktree.root_name().into();
        let abs_path = worktree.abs_path();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        // Candidates are checked in priority order; the first match wins.
        const RULES_FILE_NAMES: [&'static str; 5] = [
            ".rules",
            ".cursorrules",
            ".windsurfrules",
            ".clinerules",
            "CLAUDE.md",
        ];
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree
                    .entry_for_path(name)
                    .filter(|entry| entry.is_file())
                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
            })
            .next();

        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
            cx.spawn(async move |_| {
                let rules_file_result = maybe!(async move {
                    let abs_rules_path = abs_rules_path?;
                    let text = fs.load(&abs_rules_path).await.with_context(|| {
                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
                    })?;
                    anyhow::Ok(RulesFile {
                        rel_path: rel_rules_path,
                        abs_path: abs_rules_path.into(),
                        text: text.trim().to_string(),
                    })
                })
                .await;
                // A load failure is reported to the user but doesn't prevent
                // returning the rest of the worktree info.
                let (rules_file, rules_file_error) = match rules_file_result {
                    Ok(rules_file) => (Some(rules_file), None),
                    Err(err) => (
                        None,
                        Some(ThreadError::Message {
                            header: "Error loading rules file".into(),
                            message: format!("{err}").into(),
                        }),
                    ),
                };
                let worktree_info = WorktreeInfoForSystemPrompt {
                    root_name,
                    abs_path,
                    rules_file,
                };
                (worktree_info, rules_file_error)
            })
        } else {
            Task::ready((
                WorktreeInfoForSystemPrompt {
                    root_name,
                    abs_path,
                    rules_file: None,
                },
                None,
            ))
        }
    }
649
650 pub fn send_to_model(
651 &mut self,
652 model: Arc<dyn LanguageModel>,
653 request_kind: RequestKind,
654 cx: &mut Context<Self>,
655 ) {
656 let mut request = self.to_completion_request(request_kind, cx);
657 request.tools = {
658 let mut tools = Vec::new();
659
660 if self.tools.is_scripting_tool_enabled() {
661 tools.push(LanguageModelRequestTool {
662 name: ScriptingTool::NAME.into(),
663 description: ScriptingTool::DESCRIPTION.into(),
664 input_schema: ScriptingTool::input_schema(),
665 });
666 }
667
668 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
669 LanguageModelRequestTool {
670 name: tool.name(),
671 description: tool.description(),
672 input_schema: tool.input_schema(),
673 }
674 }));
675
676 tools
677 };
678
679 self.stream_completion(request, model, cx);
680 }
681
    /// Assembles the full [`LanguageModelRequest`] for this thread.
    ///
    /// Message order: system prompt (if loaded), then each thread message
    /// (with tool results attached *before* its text and tool uses *after*,
    /// for `Chat` requests), then one user message carrying all referenced
    /// context, then a message listing stale files.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
            if let Some(system_prompt) = self
                .prompt_builder
                .generate_assistant_system_prompt(system_prompt_context)
                .context("failed to generate assistant system prompt")
                .log_err()
            {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    // The system prompt is stable across turns, so mark it
                    // cacheable.
                    cache: true,
                });
            }
        } else {
            log::error!("system_prompt_context not set.")
        }

        let mut referenced_context_ids = HashSet::default();

        for message in &self.messages {
            if let Some(context_ids) = self.context_by_message.get(&message.id) {
                referenced_context_ids.extend(context_ids);
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results must precede the message text.
            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            }

            if !message.text.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.text.clone()));
            }

            // Tool uses go after the message text.
            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // All context referenced anywhere in the thread is bundled into a
        // single trailing user message.
        if !referenced_context_ids.is_empty() {
            let mut context_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };

            let referenced_context = referenced_context_ids
                .into_iter()
                .filter_map(|context_id| self.context.get(context_id))
                .cloned();
            attach_context_to_message(&mut context_message, referenced_context);

            request.messages.push(context_message);
        }

        self.attach_stale_files(&mut request.messages, cx);

        request
    }
777
778 fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
779 const STALE_FILES_HEADER: &str = "These files changed since last read:";
780
781 let mut stale_message = String::new();
782
783 for stale_file in self.action_log.read(cx).stale_buffers(cx) {
784 let Some(file) = stale_file.read(cx).file() else {
785 continue;
786 };
787
788 if stale_message.is_empty() {
789 write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
790 }
791
792 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
793 }
794
795 if !stale_message.is_empty() {
796 let context_message = LanguageModelRequestMessage {
797 role: Role::User,
798 content: vec![stale_message.into()],
799 cache: false,
800 };
801
802 messages.push(context_message);
803 }
804 }
805
    /// Streams a completion from `model`, updating the thread as events
    /// arrive.
    ///
    /// Handles text chunks (appending to the trailing assistant message, or
    /// creating one), tool-use requests, token-usage accounting, and stop
    /// reasons. On error, shows an appropriate `ThreadEvent::ShowError` and
    /// cancels the completion. Emits `DoneStreaming` when finished and
    /// reports per-completion token usage to telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot usage up front so we can report this completion's
            // delta at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(Role::Assistant, String::new(), cx);
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per
                                // completion, so replace the previous
                                // update's contribution rather than adding
                                // it twice.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.text.push_str(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(Role::Assistant, chunk, cx);
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent
                                // assistant message, routing scripting-tool
                                // requests to their dedicated state.
                                if let Some(last_assistant_message) = thread
                                    .messages
                                    .iter()
                                    .rfind(|message| message.role == Role::Assistant)
                                {
                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
                                        thread.scripting_tool_use.request_tool_use(
                                            last_assistant_message.id,
                                            tool_use,
                                            cx,
                                        );
                                    } else {
                                        thread.tool_use.request_tool_use(
                                            last_assistant_message.id,
                                            tool_use,
                                            cx,
                                        );
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Yield so the UI can process the update between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // Auto-generate a title once a user/assistant exchange
                    // exists.
                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known billing errors to dedicated UI
                            // states; everything else becomes a generic
                            // message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
961
    /// Asks the active model to generate a short title for the thread.
    ///
    /// Silently does nothing when no authenticated provider/model is
    /// configured. Only the first line of the model's reply is kept; the
    /// summary is stored and `SummaryChanged` emitted when non-empty.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
            return;
        };
        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
                    .into(),
            ],
            cache: false,
        });

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.stream_completion_text(request, &cx);
                let mut messages = stream.await?;

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryChanged);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1015
    /// Runs every idle pending tool use (regular and scripting) and returns
    /// the set that was started.
    ///
    /// Regular tools are dispatched through the working set; scripting tool
    /// uses run a Lua script in the thread's scripting session.
    pub fn use_pending_tools(
        &mut self,
        cx: &mut Context<Self>,
    ) -> impl IntoIterator<Item = PendingToolUse> {
        // Tools receive the conversation so far as part of their input.
        let request = self.to_completion_request(RequestKind::Chat, cx);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
                let task = tool.run(
                    tool_use.input.clone(),
                    &request.messages,
                    self.project.clone(),
                    self.action_log.clone(),
                    cx,
                );

                self.insert_tool_output(
                    tool_use.id.clone(),
                    tool_use.ui_text.clone().into(),
                    task,
                    cx,
                );
            }
        }

        let pending_scripting_tool_uses = self
            .scripting_tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for scripting_tool_use in pending_scripting_tool_uses.iter() {
            // Invalid input becomes an immediate error result rather than a
            // running script.
            let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
                Err(err) => Task::ready(Err(err.into())),
                Ok(input) => {
                    let (script_id, script_task) =
                        self.scripting_session.update(cx, move |session, cx| {
                            session.run_script(input.lua_script, cx)
                        });

                    let session = self.scripting_session.clone();
                    cx.spawn(async move |_, cx| {
                        script_task.await;

                        let message = session.read_with(cx, |session, _cx| {
                            // Using a id to get the script output seems impractical.
                            // Why not just include it in the Task result?
                            // This is because we'll later report the script state as it runs,
                            session
                                .get(script_id)
                                .output_message_for_llm()
                                .expect("Script shouldn't still be running")
                        })?;

                        Ok(message)
                    })
                }
            };

            let ui_text: SharedString = scripting_tool_use.name.clone().into();

            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
        }

        pending_tool_uses
            .into_iter()
            .chain(pending_scripting_tool_uses)
    }
1093
    /// Marks a regular tool use as running and records its output once the
    /// `output` task resolves, emitting [`ThreadEvent::ToolFinished`].
    pub fn insert_tool_output(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: SharedString,
        output: Task<Result<String>>,
        cx: &mut Context<Self>,
    ) {
        let insert_output_task = cx.spawn({
            let tool_use_id = tool_use_id.clone();
            async move |thread, cx| {
                let output = output.await;
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread
                            .tool_use
                            .insert_tool_output(tool_use_id.clone(), output);

                        cx.emit(ThreadEvent::ToolFinished {
                            tool_use_id,
                            pending_tool_use,
                            canceled: false,
                        });
                    })
                    // The thread may have been dropped; nothing to record then.
                    .ok();
            }
        });

        self.tool_use
            .run_pending_tool(tool_use_id, ui_text, insert_output_task);
    }
1124
1125 pub fn insert_scripting_tool_output(
1126 &mut self,
1127 tool_use_id: LanguageModelToolUseId,
1128 ui_text: SharedString,
1129 output: Task<Result<String>>,
1130 cx: &mut Context<Self>,
1131 ) {
1132 let insert_output_task = cx.spawn({
1133 let tool_use_id = tool_use_id.clone();
1134 async move |thread, cx| {
1135 let output = output.await;
1136 thread
1137 .update(cx, |thread, cx| {
1138 let pending_tool_use = thread
1139 .scripting_tool_use
1140 .insert_tool_output(tool_use_id.clone(), output);
1141
1142 cx.emit(ThreadEvent::ToolFinished {
1143 tool_use_id,
1144 pending_tool_use,
1145 canceled: false,
1146 });
1147 })
1148 .ok();
1149 }
1150 });
1151
1152 self.scripting_tool_use
1153 .run_pending_tool(tool_use_id, ui_text, insert_output_task);
1154 }
1155
1156 pub fn attach_tool_results(
1157 &mut self,
1158 updated_context: Vec<ContextSnapshot>,
1159 cx: &mut Context<Self>,
1160 ) {
1161 self.context.extend(
1162 updated_context
1163 .into_iter()
1164 .map(|context| (context.id, context)),
1165 );
1166
1167 // Insert a user message to contain the tool results.
1168 self.insert_user_message(
1169 // TODO: Sending up a user message without any content results in the model sending back
1170 // responses that also don't have any content. We currently don't handle this case well,
1171 // so for now we provide some text to keep the model on track.
1172 "Here are the tool results.",
1173 Vec::new(),
1174 None,
1175 cx,
1176 );
1177 }
1178
1179 /// Cancels the last pending completion, if there are any pending.
1180 ///
1181 /// Returns whether a completion was canceled.
1182 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1183 if self.pending_completions.pop().is_some() {
1184 true
1185 } else {
1186 let mut canceled = false;
1187 for pending_tool_use in self.tool_use.cancel_pending() {
1188 canceled = true;
1189 cx.emit(ThreadEvent::ToolFinished {
1190 tool_use_id: pending_tool_use.id.clone(),
1191 pending_tool_use: Some(pending_tool_use),
1192 canceled: true,
1193 });
1194 }
1195 canceled
1196 }
1197 }
1198
    /// Returns the feedback given to the thread, if any.
    ///
    /// Set by [`Thread::report_feedback`]; `None` until the user rates the thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1203
1204 /// Reports feedback about the thread and stores it in our telemetry backend.
1205 pub fn report_feedback(
1206 &mut self,
1207 feedback: ThreadFeedback,
1208 cx: &mut Context<Self>,
1209 ) -> Task<Result<()>> {
1210 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1211 let serialized_thread = self.serialize(cx);
1212 let thread_id = self.id().clone();
1213 let client = self.project.read(cx).client();
1214 self.feedback = Some(feedback);
1215 cx.notify();
1216
1217 cx.background_spawn(async move {
1218 let final_project_snapshot = final_project_snapshot.await;
1219 let serialized_thread = serialized_thread.await?;
1220 let thread_data =
1221 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1222
1223 let rating = match feedback {
1224 ThreadFeedback::Positive => "positive",
1225 ThreadFeedback::Negative => "negative",
1226 };
1227 telemetry::event!(
1228 "Assistant Thread Rated",
1229 rating,
1230 thread_id,
1231 thread_data,
1232 final_project_snapshot
1233 );
1234 client.telemetry().flush_events();
1235
1236 Ok(())
1237 })
1238 }
1239
1240 /// Create a snapshot of the current project state including git information and unsaved buffers.
1241 fn project_snapshot(
1242 project: Entity<Project>,
1243 cx: &mut Context<Self>,
1244 ) -> Task<Arc<ProjectSnapshot>> {
1245 let git_store = project.read(cx).git_store().clone();
1246 let worktree_snapshots: Vec<_> = project
1247 .read(cx)
1248 .visible_worktrees(cx)
1249 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1250 .collect();
1251
1252 cx.spawn(async move |_, cx| {
1253 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1254
1255 let mut unsaved_buffers = Vec::new();
1256 cx.update(|app_cx| {
1257 let buffer_store = project.read(app_cx).buffer_store();
1258 for buffer_handle in buffer_store.read(app_cx).buffers() {
1259 let buffer = buffer_handle.read(app_cx);
1260 if buffer.is_dirty() {
1261 if let Some(file) = buffer.file() {
1262 let path = file.path().to_string_lossy().to_string();
1263 unsaved_buffers.push(path);
1264 }
1265 }
1266 }
1267 })
1268 .ok();
1269
1270 Arc::new(ProjectSnapshot {
1271 worktree_snapshots,
1272 unsaved_buffer_paths: unsaved_buffers,
1273 timestamp: Utc::now(),
1274 })
1275 })
1276 }
1277
    /// Captures a `WorktreeSnapshot` (absolute path plus git state) for one worktree.
    ///
    /// Best-effort: if the app context is no longer available, returns a
    /// placeholder snapshot with an empty path; if no repository matches the
    /// worktree, `git_state` is `None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // `cx.update` fails when the app is shutting down — fall back to an
            // empty snapshot rather than erroring.
            let Ok((worktree_path, snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository belonging to this worktree and grab its
            // current branch together with a handle to the local repository.
            let repo_info = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| repo.read(cx).worktree_id == snapshot.id())
                        .and_then(|repo| {
                            let repo = repo.read(cx);
                            Some((repo.branch().cloned(), repo.local_repository()?))
                        })
                })
                .ok()
                .flatten();

            // Extract git information
            let git_state = match repo_info {
                None => None,
                Some((branch, repo)) => {
                    let current_branch = branch.map(|branch| branch.name.to_string());
                    let remote_url = repo.remote_url("origin");
                    let head_sha = repo.head_sha();

                    // Get diff asynchronously; a failed diff is reported as `None`.
                    let diff = repo
                        .diff(git::repository::DiffType::HeadToWorktree, cx.clone())
                        .await
                        .ok();

                    Some(GitState {
                        remote_url,
                        head_sha,
                        current_branch,
                        diff,
                    })
                }
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1342
1343 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1344 let mut markdown = Vec::new();
1345
1346 if let Some(summary) = self.summary() {
1347 writeln!(markdown, "# {summary}\n")?;
1348 };
1349
1350 for message in self.messages() {
1351 writeln!(
1352 markdown,
1353 "## {role}\n",
1354 role = match message.role {
1355 Role::User => "User",
1356 Role::Assistant => "Assistant",
1357 Role::System => "System",
1358 }
1359 )?;
1360 writeln!(markdown, "{}\n", message.text)?;
1361
1362 for tool_use in self.tool_uses_for_message(message.id, cx) {
1363 writeln!(
1364 markdown,
1365 "**Use Tool: {} ({})**",
1366 tool_use.name, tool_use.id
1367 )?;
1368 writeln!(markdown, "```json")?;
1369 writeln!(
1370 markdown,
1371 "{}",
1372 serde_json::to_string_pretty(&tool_use.input)?
1373 )?;
1374 writeln!(markdown, "```")?;
1375 }
1376
1377 for tool_result in self.tool_results_for_message(message.id) {
1378 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1379 if tool_result.is_error {
1380 write!(markdown, " (Error)")?;
1381 }
1382
1383 writeln!(markdown, "**\n")?;
1384 writeln!(markdown, "{}", tool_result.content)?;
1385 }
1386 }
1387
1388 Ok(String::from_utf8_lossy(&markdown).to_string())
1389 }
1390
    /// Returns the action log used to track tool-driven changes.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1394
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1398
    /// Returns the total token usage accumulated over this thread's completions.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage.clone()
    }
1402}
1403
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// The completion provider requires payment before continuing.
    PaymentRequired,
    /// The user's configured maximum monthly spend has been reached.
    MaxMonthlySpendReached,
    /// A generic error, displayed with a short header and a longer message.
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1413
/// Events emitted by a `Thread` while streaming completions and running tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// A piece of the in-flight completion arrived.
    StreamedCompletion,
    /// Streamed assistant text to append to the given message.
    StreamedAssistantText(MessageId, String),
    /// The completion finished streaming.
    DoneStreaming,
    /// A new message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses requested by the model should now be executed.
    UsePendingTools,
    /// A tool finished running, either normally or by cancellation.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
}
1435
// Marker impl letting `Thread` emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
1437
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from others on the thread.
    id: usize,
    // Held to keep the completion running; popping this entry (see
    // `cancel_last_completion`) drops the task and thereby cancels the request.
    _task: Task<()>,
}