1use std::io::Write;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use assistant_tool::{ActionLog, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::{BTreeMap, HashMap, HashSet};
8use futures::future::Shared;
9use futures::{FutureExt, StreamExt as _};
10use git;
11use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
12use language_model::{
13 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
14 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
15 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
16 Role, StopReason, TokenUsage,
17};
18use project::Project;
19use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
20use scripting_tool::{ScriptingSession, ScriptingTool};
21use serde::{Deserialize, Serialize};
22use util::{post_inc, ResultExt, TryFutureExt as _};
23use uuid::Uuid;
24
25use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
26use crate::thread_store::{
27 SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
28};
29use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
30
/// The purpose of a completion request built from a [`Thread`].
///
/// The kind controls what [`Thread::to_completion_request`] includes: tool uses and tool
/// results are attached only for [`RequestKind::Chat`].
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A regular conversational turn.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
37
/// Identifier of a [`Thread`], stored as a shared string.
///
/// New ids are random UUIDs (see `ThreadId::new`); the id is also serializable so it can be
/// persisted alongside the thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);
40
41impl ThreadId {
42 pub fn new() -> Self {
43 Self(Uuid::new_v4().to_string().into())
44 }
45}
46
47impl std::fmt::Display for ThreadId {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "{}", self.0)
50 }
51}
52
/// Identifier of a [`Message`] within a single [`Thread`].
///
/// Ids are handed out in increasing order by the thread (see `MessageId::post_inc`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
55
56impl MessageId {
57 fn post_inc(&mut self) -> Self {
58 Self(post_inc(&mut self.0))
59 }
60}
61
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message, unique within its thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// The message's text. Assistant messages grow as streamed text chunks arrive.
    pub text: String,
}
69
/// A point-in-time record of the project's state, attached to serialized threads and
/// feedback telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per visible worktree at the time of the snapshot.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes and were backed by a file.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
76
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Absolute path of the worktree (empty if it could not be read).
    pub worktree_path: String,
    /// Git information for the worktree's first repository, if any could be gathered.
    pub git_state: Option<GitState>,
}
82
/// Git details captured for a worktree's repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the `origin` remote, if available.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`, if available.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff from `HEAD` to the working tree, if it could be computed.
    pub diff: Option<String>,
}
90
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Time of the last modification (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    /// Short title for the thread, set by `summarize` or `set_summary`.
    summary: Option<SharedString>,
    /// The most recent summarization task; replaced whenever `summarize` runs.
    pending_summary: Task<Option<()>>,
    /// The conversation, in insertion order.
    messages: Vec<Message>,
    /// Id that will be assigned to the next inserted message.
    next_message_id: MessageId,
    /// Context snapshots referenced by messages, keyed by their id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Which context ids each message referenced when it was inserted.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Counter used to assign ids to pending completions.
    completion_count: usize,
    /// Completion streams currently in flight.
    pending_completions: Vec<PendingCompletion>,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Builds the system prompt for completion requests.
    prompt_builder: Arc<PromptBuilder>,
    /// The set of tools the model may call.
    tools: Arc<ToolWorkingSet>,
    /// State of tool uses other than the scripting tool.
    tool_use: ToolUseState,
    /// Action log handed to tools when they run.
    action_log: Entity<ActionLog>,
    /// Session in which scripting-tool Lua scripts are executed.
    scripting_session: Entity<ScriptingSession>,
    /// State of scripting-tool uses.
    scripting_tool_use: ToolUseState,
    /// Project snapshot taken when the thread was created (or loaded from storage).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage summed across completions in this session (not persisted).
    cumulative_token_usage: TokenUsage,
}
113
114impl Thread {
115 pub fn new(
116 project: Entity<Project>,
117 tools: Arc<ToolWorkingSet>,
118 prompt_builder: Arc<PromptBuilder>,
119 cx: &mut Context<Self>,
120 ) -> Self {
121 Self {
122 id: ThreadId::new(),
123 updated_at: Utc::now(),
124 summary: None,
125 pending_summary: Task::ready(None),
126 messages: Vec::new(),
127 next_message_id: MessageId(0),
128 context: BTreeMap::default(),
129 context_by_message: HashMap::default(),
130 completion_count: 0,
131 pending_completions: Vec::new(),
132 project: project.clone(),
133 prompt_builder,
134 tools,
135 tool_use: ToolUseState::new(),
136 scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
137 scripting_tool_use: ToolUseState::new(),
138 action_log: cx.new(|_| ActionLog::new()),
139 initial_project_snapshot: {
140 let project_snapshot = Self::project_snapshot(project, cx);
141 cx.foreground_executor()
142 .spawn(async move { Some(project_snapshot.await) })
143 .shared()
144 },
145 cumulative_token_usage: TokenUsage::default(),
146 }
147 }
148
149 pub fn deserialize(
150 id: ThreadId,
151 serialized: SerializedThread,
152 project: Entity<Project>,
153 tools: Arc<ToolWorkingSet>,
154 prompt_builder: Arc<PromptBuilder>,
155 cx: &mut Context<Self>,
156 ) -> Self {
157 let next_message_id = MessageId(
158 serialized
159 .messages
160 .last()
161 .map(|message| message.id.0 + 1)
162 .unwrap_or(0),
163 );
164 let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
165 name != ScriptingTool::NAME
166 });
167 let scripting_tool_use =
168 ToolUseState::from_serialized_messages(&serialized.messages, |name| {
169 name == ScriptingTool::NAME
170 });
171 let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
172
173 Self {
174 id,
175 updated_at: serialized.updated_at,
176 summary: Some(serialized.summary),
177 pending_summary: Task::ready(None),
178 messages: serialized
179 .messages
180 .into_iter()
181 .map(|message| Message {
182 id: message.id,
183 role: message.role,
184 text: message.text,
185 })
186 .collect(),
187 next_message_id,
188 context: BTreeMap::default(),
189 context_by_message: HashMap::default(),
190 completion_count: 0,
191 pending_completions: Vec::new(),
192 project,
193 prompt_builder,
194 tools,
195 tool_use,
196 action_log: cx.new(|_| ActionLog::new()),
197 scripting_session,
198 scripting_tool_use,
199 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
200 // TODO: persist token usage?
201 cumulative_token_usage: TokenUsage::default(),
202 }
203 }
204
205 pub fn id(&self) -> &ThreadId {
206 &self.id
207 }
208
209 pub fn is_empty(&self) -> bool {
210 self.messages.is_empty()
211 }
212
213 pub fn updated_at(&self) -> DateTime<Utc> {
214 self.updated_at
215 }
216
217 pub fn touch_updated_at(&mut self) {
218 self.updated_at = Utc::now();
219 }
220
221 pub fn summary(&self) -> Option<SharedString> {
222 self.summary.clone()
223 }
224
225 pub fn summary_or_default(&self) -> SharedString {
226 const DEFAULT: SharedString = SharedString::new_static("New Thread");
227 self.summary.clone().unwrap_or(DEFAULT)
228 }
229
230 pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
231 self.summary = Some(summary.into());
232 cx.emit(ThreadEvent::SummaryChanged);
233 }
234
235 pub fn message(&self, id: MessageId) -> Option<&Message> {
236 self.messages.iter().find(|message| message.id == id)
237 }
238
239 pub fn messages(&self) -> impl Iterator<Item = &Message> {
240 self.messages.iter()
241 }
242
243 pub fn is_streaming(&self) -> bool {
244 !self.pending_completions.is_empty() || !self.all_tools_finished()
245 }
246
247 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
248 &self.tools
249 }
250
251 pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
252 let context = self.context_by_message.get(&id)?;
253 Some(
254 context
255 .into_iter()
256 .filter_map(|context_id| self.context.get(&context_id))
257 .cloned()
258 .collect::<Vec<_>>(),
259 )
260 }
261
262 /// Returns whether all of the tool uses have finished running.
263 pub fn all_tools_finished(&self) -> bool {
264 let mut all_pending_tool_uses = self
265 .tool_use
266 .pending_tool_uses()
267 .into_iter()
268 .chain(self.scripting_tool_use.pending_tool_uses());
269
270 // If the only pending tool uses left are the ones with errors, then that means that we've finished running all
271 // of the pending tools.
272 all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
273 }
274
275 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
276 self.tool_use.tool_uses_for_message(id)
277 }
278
279 pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
280 self.scripting_tool_use.tool_uses_for_message(id)
281 }
282
283 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
284 self.tool_use.tool_results_for_message(id)
285 }
286
287 pub fn scripting_tool_results_for_message(
288 &self,
289 id: MessageId,
290 ) -> Vec<&LanguageModelToolResult> {
291 self.scripting_tool_use.tool_results_for_message(id)
292 }
293
294 pub fn scripting_changed_buffers<'a>(
295 &self,
296 cx: &'a App,
297 ) -> impl ExactSizeIterator<Item = &'a Entity<language::Buffer>> {
298 self.scripting_session.read(cx).changed_buffers()
299 }
300
301 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
302 self.tool_use.message_has_tool_results(message_id)
303 }
304
305 pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
306 self.scripting_tool_use.message_has_tool_results(message_id)
307 }
308
309 pub fn insert_user_message(
310 &mut self,
311 text: impl Into<String>,
312 context: Vec<ContextSnapshot>,
313 cx: &mut Context<Self>,
314 ) -> MessageId {
315 let message_id = self.insert_message(Role::User, text, cx);
316 let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
317 self.context
318 .extend(context.into_iter().map(|context| (context.id, context)));
319 self.context_by_message.insert(message_id, context_ids);
320 message_id
321 }
322
323 pub fn insert_message(
324 &mut self,
325 role: Role,
326 text: impl Into<String>,
327 cx: &mut Context<Self>,
328 ) -> MessageId {
329 let id = self.next_message_id.post_inc();
330 self.messages.push(Message {
331 id,
332 role,
333 text: text.into(),
334 });
335 self.touch_updated_at();
336 cx.emit(ThreadEvent::MessageAdded(id));
337 id
338 }
339
340 pub fn edit_message(
341 &mut self,
342 id: MessageId,
343 new_role: Role,
344 new_text: String,
345 cx: &mut Context<Self>,
346 ) -> bool {
347 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
348 return false;
349 };
350 message.role = new_role;
351 message.text = new_text;
352 self.touch_updated_at();
353 cx.emit(ThreadEvent::MessageEdited(id));
354 true
355 }
356
357 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
358 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
359 return false;
360 };
361 self.messages.remove(index);
362 self.context_by_message.remove(&id);
363 self.touch_updated_at();
364 cx.emit(ThreadEvent::MessageDeleted(id));
365 true
366 }
367
368 /// Returns the representation of this [`Thread`] in a textual form.
369 ///
370 /// This is the representation we use when attaching a thread as context to another thread.
371 pub fn text(&self) -> String {
372 let mut text = String::new();
373
374 for message in &self.messages {
375 text.push_str(match message.role {
376 language_model::Role::User => "User:",
377 language_model::Role::Assistant => "Assistant:",
378 language_model::Role::System => "System:",
379 });
380 text.push('\n');
381
382 text.push_str(&message.text);
383 text.push('\n');
384 }
385
386 text
387 }
388
389 /// Serializes this thread into a format for storage or telemetry.
390 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
391 let initial_project_snapshot = self.initial_project_snapshot.clone();
392 cx.spawn(|this, cx| async move {
393 let initial_project_snapshot = initial_project_snapshot.await;
394 this.read_with(&cx, |this, _| SerializedThread {
395 summary: this.summary_or_default(),
396 updated_at: this.updated_at(),
397 messages: this
398 .messages()
399 .map(|message| SerializedMessage {
400 id: message.id,
401 role: message.role,
402 text: message.text.clone(),
403 tool_uses: this
404 .tool_uses_for_message(message.id)
405 .into_iter()
406 .chain(this.scripting_tool_uses_for_message(message.id))
407 .map(|tool_use| SerializedToolUse {
408 id: tool_use.id,
409 name: tool_use.name,
410 input: tool_use.input,
411 })
412 .collect(),
413 tool_results: this
414 .tool_results_for_message(message.id)
415 .into_iter()
416 .chain(this.scripting_tool_results_for_message(message.id))
417 .map(|tool_result| SerializedToolResult {
418 tool_use_id: tool_result.tool_use_id.clone(),
419 is_error: tool_result.is_error,
420 content: tool_result.content.clone(),
421 })
422 .collect(),
423 })
424 .collect(),
425 initial_project_snapshot,
426 })
427 })
428 }
429
430 pub fn send_to_model(
431 &mut self,
432 model: Arc<dyn LanguageModel>,
433 request_kind: RequestKind,
434 cx: &mut Context<Self>,
435 ) {
436 let mut request = self.to_completion_request(request_kind, cx);
437 request.tools = {
438 let mut tools = Vec::new();
439
440 if self.tools.is_scripting_tool_enabled() {
441 tools.push(LanguageModelRequestTool {
442 name: ScriptingTool::NAME.into(),
443 description: ScriptingTool::DESCRIPTION.into(),
444 input_schema: ScriptingTool::input_schema(),
445 });
446 }
447
448 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
449 LanguageModelRequestTool {
450 name: tool.name(),
451 description: tool.description(),
452 input_schema: tool.input_schema(),
453 }
454 }));
455
456 tools
457 };
458
459 self.stream_completion(request, model, cx);
460 }
461
462 pub fn to_completion_request(
463 &self,
464 request_kind: RequestKind,
465 cx: &App,
466 ) -> LanguageModelRequest {
467 let worktree_root_names = self
468 .project
469 .read(cx)
470 .visible_worktrees(cx)
471 .map(|worktree| {
472 let worktree = worktree.read(cx);
473 AssistantSystemPromptWorktree {
474 root_name: worktree.root_name().into(),
475 abs_path: worktree.abs_path(),
476 }
477 })
478 .collect::<Vec<_>>();
479 let system_prompt = self
480 .prompt_builder
481 .generate_assistant_system_prompt(worktree_root_names)
482 .context("failed to generate assistant system prompt")
483 .log_err()
484 .unwrap_or_default();
485
486 let mut request = LanguageModelRequest {
487 messages: vec![LanguageModelRequestMessage {
488 role: Role::System,
489 content: vec![MessageContent::Text(system_prompt)],
490 cache: true,
491 }],
492 tools: Vec::new(),
493 stop: Vec::new(),
494 temperature: None,
495 };
496
497 let mut referenced_context_ids = HashSet::default();
498
499 for message in &self.messages {
500 if let Some(context_ids) = self.context_by_message.get(&message.id) {
501 referenced_context_ids.extend(context_ids);
502 }
503
504 let mut request_message = LanguageModelRequestMessage {
505 role: message.role,
506 content: Vec::new(),
507 cache: false,
508 };
509
510 match request_kind {
511 RequestKind::Chat => {
512 self.tool_use
513 .attach_tool_results(message.id, &mut request_message);
514 self.scripting_tool_use
515 .attach_tool_results(message.id, &mut request_message);
516 }
517 RequestKind::Summarize => {
518 // We don't care about tool use during summarization.
519 }
520 }
521
522 if !message.text.is_empty() {
523 request_message
524 .content
525 .push(MessageContent::Text(message.text.clone()));
526 }
527
528 match request_kind {
529 RequestKind::Chat => {
530 self.tool_use
531 .attach_tool_uses(message.id, &mut request_message);
532 self.scripting_tool_use
533 .attach_tool_uses(message.id, &mut request_message);
534 }
535 RequestKind::Summarize => {
536 // We don't care about tool use during summarization.
537 }
538 };
539
540 request.messages.push(request_message);
541 }
542
543 if !referenced_context_ids.is_empty() {
544 let mut context_message = LanguageModelRequestMessage {
545 role: Role::User,
546 content: Vec::new(),
547 cache: false,
548 };
549
550 let referenced_context = referenced_context_ids
551 .into_iter()
552 .filter_map(|context_id| self.context.get(context_id))
553 .cloned();
554 attach_context_to_message(&mut context_message, referenced_context);
555
556 request.messages.push(context_message);
557 }
558
559 request
560 }
561
562 pub fn stream_completion(
563 &mut self,
564 request: LanguageModelRequest,
565 model: Arc<dyn LanguageModel>,
566 cx: &mut Context<Self>,
567 ) {
568 let pending_completion_id = post_inc(&mut self.completion_count);
569
570 let task = cx.spawn(|thread, mut cx| async move {
571 let stream = model.stream_completion(request, &cx);
572 let stream_completion = async {
573 let mut events = stream.await?;
574 let mut stop_reason = StopReason::EndTurn;
575 let mut current_token_usage = TokenUsage::default();
576
577 while let Some(event) = events.next().await {
578 let event = event?;
579
580 thread.update(&mut cx, |thread, cx| {
581 match event {
582 LanguageModelCompletionEvent::StartMessage { .. } => {
583 thread.insert_message(Role::Assistant, String::new(), cx);
584 }
585 LanguageModelCompletionEvent::Stop(reason) => {
586 stop_reason = reason;
587 }
588 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
589 thread.cumulative_token_usage =
590 thread.cumulative_token_usage.clone() + token_usage.clone()
591 - current_token_usage.clone();
592 current_token_usage = token_usage;
593 }
594 LanguageModelCompletionEvent::Text(chunk) => {
595 if let Some(last_message) = thread.messages.last_mut() {
596 if last_message.role == Role::Assistant {
597 last_message.text.push_str(&chunk);
598 cx.emit(ThreadEvent::StreamedAssistantText(
599 last_message.id,
600 chunk,
601 ));
602 } else {
603 // If we won't have an Assistant message yet, assume this chunk marks the beginning
604 // of a new Assistant response.
605 //
606 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
607 // will result in duplicating the text of the chunk in the rendered Markdown.
608 thread.insert_message(Role::Assistant, chunk, cx);
609 };
610 }
611 }
612 LanguageModelCompletionEvent::ToolUse(tool_use) => {
613 if let Some(last_assistant_message) = thread
614 .messages
615 .iter()
616 .rfind(|message| message.role == Role::Assistant)
617 {
618 if tool_use.name.as_ref() == ScriptingTool::NAME {
619 thread
620 .scripting_tool_use
621 .request_tool_use(last_assistant_message.id, tool_use);
622 } else {
623 thread
624 .tool_use
625 .request_tool_use(last_assistant_message.id, tool_use);
626 }
627 }
628 }
629 }
630
631 thread.touch_updated_at();
632 cx.emit(ThreadEvent::StreamedCompletion);
633 cx.notify();
634 })?;
635
636 smol::future::yield_now().await;
637 }
638
639 thread.update(&mut cx, |thread, cx| {
640 thread
641 .pending_completions
642 .retain(|completion| completion.id != pending_completion_id);
643
644 if thread.summary.is_none() && thread.messages.len() >= 2 {
645 thread.summarize(cx);
646 }
647 })?;
648
649 anyhow::Ok(stop_reason)
650 };
651
652 let result = stream_completion.await;
653
654 thread
655 .update(&mut cx, |thread, cx| match result.as_ref() {
656 Ok(stop_reason) => match stop_reason {
657 StopReason::ToolUse => {
658 cx.emit(ThreadEvent::UsePendingTools);
659 }
660 StopReason::EndTurn => {}
661 StopReason::MaxTokens => {}
662 },
663 Err(error) => {
664 if error.is::<PaymentRequiredError>() {
665 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
666 } else if error.is::<MaxMonthlySpendReachedError>() {
667 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
668 } else {
669 let error_message = error
670 .chain()
671 .map(|err| err.to_string())
672 .collect::<Vec<_>>()
673 .join("\n");
674 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
675 SharedString::from(error_message.clone()),
676 )));
677 }
678
679 thread.cancel_last_completion();
680 }
681 })
682 .ok();
683 });
684
685 self.pending_completions.push(PendingCompletion {
686 id: pending_completion_id,
687 _task: task,
688 });
689 }
690
691 pub fn summarize(&mut self, cx: &mut Context<Self>) {
692 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
693 return;
694 };
695 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
696 return;
697 };
698
699 if !provider.is_authenticated(cx) {
700 return;
701 }
702
703 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
704 request.messages.push(LanguageModelRequestMessage {
705 role: Role::User,
706 content: vec![
707 "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
708 .into(),
709 ],
710 cache: false,
711 });
712
713 self.pending_summary = cx.spawn(|this, mut cx| {
714 async move {
715 let stream = model.stream_completion_text(request, &cx);
716 let mut messages = stream.await?;
717
718 let mut new_summary = String::new();
719 while let Some(message) = messages.stream.next().await {
720 let text = message?;
721 let mut lines = text.lines();
722 new_summary.extend(lines.next());
723
724 // Stop if the LLM generated multiple lines.
725 if lines.next().is_some() {
726 break;
727 }
728 }
729
730 this.update(&mut cx, |this, cx| {
731 if !new_summary.is_empty() {
732 this.summary = Some(new_summary.into());
733 }
734
735 cx.emit(ThreadEvent::SummaryChanged);
736 })?;
737
738 anyhow::Ok(())
739 }
740 .log_err()
741 });
742 }
743
744 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
745 let request = self.to_completion_request(RequestKind::Chat, cx);
746 let pending_tool_uses = self
747 .tool_use
748 .pending_tool_uses()
749 .into_iter()
750 .filter(|tool_use| tool_use.status.is_idle())
751 .cloned()
752 .collect::<Vec<_>>();
753
754 for tool_use in pending_tool_uses {
755 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
756 let task = tool.run(
757 tool_use.input,
758 &request.messages,
759 self.project.clone(),
760 self.action_log.clone(),
761 cx,
762 );
763
764 self.insert_tool_output(tool_use.id.clone(), task, cx);
765 }
766 }
767
768 let pending_scripting_tool_uses = self
769 .scripting_tool_use
770 .pending_tool_uses()
771 .into_iter()
772 .filter(|tool_use| tool_use.status.is_idle())
773 .cloned()
774 .collect::<Vec<_>>();
775
776 for scripting_tool_use in pending_scripting_tool_uses {
777 let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
778 Err(err) => Task::ready(Err(err.into())),
779 Ok(input) => {
780 let (script_id, script_task) =
781 self.scripting_session.update(cx, move |session, cx| {
782 session.run_script(input.lua_script, cx)
783 });
784
785 let session = self.scripting_session.clone();
786 cx.spawn(|_, cx| async move {
787 script_task.await;
788
789 let message = session.read_with(&cx, |session, _cx| {
790 // Using a id to get the script output seems impractical.
791 // Why not just include it in the Task result?
792 // This is because we'll later report the script state as it runs,
793 session
794 .get(script_id)
795 .output_message_for_llm()
796 .expect("Script shouldn't still be running")
797 })?;
798
799 Ok(message)
800 })
801 }
802 };
803
804 self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
805 }
806 }
807
808 pub fn insert_tool_output(
809 &mut self,
810 tool_use_id: LanguageModelToolUseId,
811 output: Task<Result<String>>,
812 cx: &mut Context<Self>,
813 ) {
814 let insert_output_task = cx.spawn(|thread, mut cx| {
815 let tool_use_id = tool_use_id.clone();
816 async move {
817 let output = output.await;
818 thread
819 .update(&mut cx, |thread, cx| {
820 let pending_tool_use = thread
821 .tool_use
822 .insert_tool_output(tool_use_id.clone(), output);
823
824 cx.emit(ThreadEvent::ToolFinished {
825 tool_use_id,
826 pending_tool_use,
827 });
828 })
829 .ok();
830 }
831 });
832
833 self.tool_use
834 .run_pending_tool(tool_use_id, insert_output_task);
835 }
836
837 pub fn insert_scripting_tool_output(
838 &mut self,
839 tool_use_id: LanguageModelToolUseId,
840 output: Task<Result<String>>,
841 cx: &mut Context<Self>,
842 ) {
843 let insert_output_task = cx.spawn(|thread, mut cx| {
844 let tool_use_id = tool_use_id.clone();
845 async move {
846 let output = output.await;
847 thread
848 .update(&mut cx, |thread, cx| {
849 let pending_tool_use = thread
850 .scripting_tool_use
851 .insert_tool_output(tool_use_id.clone(), output);
852
853 cx.emit(ThreadEvent::ToolFinished {
854 tool_use_id,
855 pending_tool_use,
856 });
857 })
858 .ok();
859 }
860 });
861
862 self.scripting_tool_use
863 .run_pending_tool(tool_use_id, insert_output_task);
864 }
865
866 pub fn send_tool_results_to_model(
867 &mut self,
868 model: Arc<dyn LanguageModel>,
869 updated_context: Vec<ContextSnapshot>,
870 cx: &mut Context<Self>,
871 ) {
872 self.context.extend(
873 updated_context
874 .into_iter()
875 .map(|context| (context.id, context)),
876 );
877
878 // Insert a user message to contain the tool results.
879 self.insert_user_message(
880 // TODO: Sending up a user message without any content results in the model sending back
881 // responses that also don't have any content. We currently don't handle this case well,
882 // so for now we provide some text to keep the model on track.
883 "Here are the tool results.",
884 Vec::new(),
885 cx,
886 );
887 self.send_to_model(model, RequestKind::Chat, cx);
888 }
889
890 /// Cancels the last pending completion, if there are any pending.
891 ///
892 /// Returns whether a completion was canceled.
893 pub fn cancel_last_completion(&mut self) -> bool {
894 if let Some(_last_completion) = self.pending_completions.pop() {
895 true
896 } else {
897 false
898 }
899 }
900
901 /// Reports feedback about the thread and stores it in our telemetry backend.
902 pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
903 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
904 let serialized_thread = self.serialize(cx);
905 let thread_id = self.id().clone();
906 let client = self.project.read(cx).client();
907
908 cx.background_spawn(async move {
909 let final_project_snapshot = final_project_snapshot.await;
910 let serialized_thread = serialized_thread.await?;
911 let thread_data =
912 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
913
914 let rating = if is_positive { "positive" } else { "negative" };
915 telemetry::event!(
916 "Assistant Thread Rated",
917 rating,
918 thread_id,
919 thread_data,
920 final_project_snapshot
921 );
922 client.telemetry().flush_events();
923
924 Ok(())
925 })
926 }
927
928 /// Create a snapshot of the current project state including git information and unsaved buffers.
929 fn project_snapshot(
930 project: Entity<Project>,
931 cx: &mut Context<Self>,
932 ) -> Task<Arc<ProjectSnapshot>> {
933 let worktree_snapshots: Vec<_> = project
934 .read(cx)
935 .visible_worktrees(cx)
936 .map(|worktree| Self::worktree_snapshot(worktree, cx))
937 .collect();
938
939 cx.spawn(move |_, cx| async move {
940 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
941
942 let mut unsaved_buffers = Vec::new();
943 cx.update(|app_cx| {
944 let buffer_store = project.read(app_cx).buffer_store();
945 for buffer_handle in buffer_store.read(app_cx).buffers() {
946 let buffer = buffer_handle.read(app_cx);
947 if buffer.is_dirty() {
948 if let Some(file) = buffer.file() {
949 let path = file.path().to_string_lossy().to_string();
950 unsaved_buffers.push(path);
951 }
952 }
953 }
954 })
955 .ok();
956
957 Arc::new(ProjectSnapshot {
958 worktree_snapshots,
959 unsaved_buffer_paths: unsaved_buffers,
960 timestamp: Utc::now(),
961 })
962 })
963 }
964
    /// Captures a [`WorktreeSnapshot`] for `worktree` on a spawned task.
    ///
    /// The worktree's absolute path is always recorded when the worktree can be read. Git
    /// information is gathered only for the first repository in the worktree's snapshot, and
    /// only when the worktree is local and a local repository handle is found for it; otherwise
    /// `git_state` is `None`. If reading the worktree fails, an empty path and no git state are
    /// returned.
    fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
        cx.spawn(move |cx| async move {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // NOTE(review): presumably this fails only when the app context is gone — confirm.
            // Either way, fall back to an empty snapshot instead of propagating the error.
            let Ok((worktree_path, snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Extract git information
            let git_state = match snapshot.repositories().first() {
                None => None,
                Some(repo_entry) => {
                    // Get branch information
                    let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());

                    // Get repository info: the `origin` remote URL, the HEAD sha, and a handle to
                    // the repository for the diff below. Only local worktrees are inspected.
                    let repo_result = worktree.read_with(&cx, |worktree, _cx| {
                        if let project::Worktree::Local(local_worktree) = &worktree {
                            local_worktree.get_local_repo(repo_entry).map(|local_repo| {
                                let repo = local_repo.repo();
                                (repo.remote_url("origin"), repo.head_sha(), repo.clone())
                            })
                        } else {
                            None
                        }
                    });

                    match repo_result {
                        Ok(Some((remote_url, head_sha, repository))) => {
                            // Get diff asynchronously; a failed diff is recorded as `None`
                            // without discarding the rest of the git state.
                            let diff = repository
                                .diff(git::repository::DiffType::HeadToWorktree, cx)
                                .await
                                .ok();

                            Some(GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            })
                        }
                        Err(_) | Ok(None) => None,
                    }
                }
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1027
1028 pub fn to_markdown(&self) -> Result<String> {
1029 let mut markdown = Vec::new();
1030
1031 if let Some(summary) = self.summary() {
1032 writeln!(markdown, "# {summary}\n")?;
1033 };
1034
1035 for message in self.messages() {
1036 writeln!(
1037 markdown,
1038 "## {role}\n",
1039 role = match message.role {
1040 Role::User => "User",
1041 Role::Assistant => "Assistant",
1042 Role::System => "System",
1043 }
1044 )?;
1045 writeln!(markdown, "{}\n", message.text)?;
1046
1047 for tool_use in self.tool_uses_for_message(message.id) {
1048 writeln!(
1049 markdown,
1050 "**Use Tool: {} ({})**",
1051 tool_use.name, tool_use.id
1052 )?;
1053 writeln!(markdown, "```json")?;
1054 writeln!(
1055 markdown,
1056 "{}",
1057 serde_json::to_string_pretty(&tool_use.input)?
1058 )?;
1059 writeln!(markdown, "```")?;
1060 }
1061
1062 for tool_result in self.tool_results_for_message(message.id) {
1063 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1064 if tool_result.is_error {
1065 write!(markdown, " (Error)")?;
1066 }
1067
1068 writeln!(markdown, "**\n")?;
1069 writeln!(markdown, "{}", tool_result.content)?;
1070 }
1071 }
1072
1073 Ok(String::from_utf8_lossy(&markdown).to_string())
1074 }
1075
1076 pub fn action_log(&self) -> &Entity<ActionLog> {
1077 &self.action_log
1078 }
1079
1080 pub fn cumulative_token_usage(&self) -> TokenUsage {
1081 self.cumulative_token_usage.clone()
1082 }
1083}
1084
/// An error surfaced to the user through [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// The completion failed with a `PaymentRequiredError`.
    PaymentRequired,
    /// The completion failed with a `MaxMonthlySpendReachedError`.
    MaxMonthlySpendReached,
    /// Any other failure; holds the error chain's messages joined by newlines.
    Message(SharedString),
}
1091
/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was applied to the thread.
    StreamedCompletion,
    /// A text chunk was appended to an existing assistant message.
    StreamedAssistantText(MessageId, String),
    /// A message was inserted.
    MessageAdded(MessageId),
    /// A message's role/text was replaced.
    MessageEdited(MessageId),
    /// A message was removed.
    MessageDeleted(MessageId),
    /// The summary changed (or a summarization attempt completed).
    SummaryChanged,
    /// The model stopped to use tools; pending tool uses should be run.
    UsePendingTools,
    /// A tool use produced its output.
    ToolFinished {
        // NOTE(review): presumably allowed because no consumer reads this field yet — confirm.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
}
1109
/// Lets observers subscribe to [`ThreadEvent`]s emitted by a [`Thread`].
impl EventEmitter<ThreadEvent> for Thread {}
1111
/// A completion stream that is still in flight.
struct PendingCompletion {
    /// Id assigned from the thread's completion counter; used to remove this entry.
    id: usize,
    /// The streaming task. It is held here so it stays alive; removing the entry drops it.
    _task: Task<()>,
}