1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::ToolWorkingSet;
5use chrono::{DateTime, Utc};
6use collections::{BTreeMap, HashMap, HashSet};
7use futures::StreamExt as _;
8use gpui::{App, Context, EventEmitter, SharedString, Task};
9use language_model::{
10 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
11 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
12 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
13 Role, StopReason,
14};
15use serde::{Deserialize, Serialize};
16use util::{post_inc, TryFutureExt as _};
17use uuid::Uuid;
18
19use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
20use crate::thread_store::SavedThread;
21use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
22
/// The purpose of a completion request built from a [`Thread`].
#[derive(Clone, Copy, Debug)]
pub enum RequestKind {
    /// A regular conversational turn; tool uses and tool results are included.
    Chat,
    /// Used when summarizing a thread; tool uses and tool results are left out.
    Summarize,
}
29
30#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
31pub struct ThreadId(Arc<str>);
32
33impl ThreadId {
34 pub fn new() -> Self {
35 Self(Uuid::new_v4().to_string().into())
36 }
37}
38
39impl std::fmt::Display for ThreadId {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0)
42 }
43}
44
45#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
46pub struct MessageId(pub(crate) usize);
47
48impl MessageId {
49 fn post_inc(&mut self) -> Self {
50 Self(post_inc(&mut self.0))
51 }
52}
53
54/// A message in a [`Thread`].
55#[derive(Debug, Clone)]
56pub struct Message {
57 pub id: MessageId,
58 pub role: Role,
59 pub text: String,
60}
61
62/// A thread of conversation with the LLM.
63pub struct Thread {
64 id: ThreadId,
65 updated_at: DateTime<Utc>,
66 summary: Option<SharedString>,
67 pending_summary: Task<Option<()>>,
68 messages: Vec<Message>,
69 next_message_id: MessageId,
70 context: BTreeMap<ContextId, ContextSnapshot>,
71 context_by_message: HashMap<MessageId, Vec<ContextId>>,
72 completion_count: usize,
73 pending_completions: Vec<PendingCompletion>,
74 tools: Arc<ToolWorkingSet>,
75 tool_use: ToolUseState,
76}
77
78impl Thread {
79 pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
80 Self {
81 id: ThreadId::new(),
82 updated_at: Utc::now(),
83 summary: None,
84 pending_summary: Task::ready(None),
85 messages: Vec::new(),
86 next_message_id: MessageId(0),
87 context: BTreeMap::default(),
88 context_by_message: HashMap::default(),
89 completion_count: 0,
90 pending_completions: Vec::new(),
91 tools,
92 tool_use: ToolUseState::new(),
93 }
94 }
95
96 pub fn from_saved(
97 id: ThreadId,
98 saved: SavedThread,
99 tools: Arc<ToolWorkingSet>,
100 _cx: &mut Context<Self>,
101 ) -> Self {
102 let next_message_id = MessageId(saved.messages.len());
103 let tool_use = ToolUseState::from_saved_messages(&saved.messages);
104
105 Self {
106 id,
107 updated_at: saved.updated_at,
108 summary: Some(saved.summary),
109 pending_summary: Task::ready(None),
110 messages: saved
111 .messages
112 .into_iter()
113 .map(|message| Message {
114 id: message.id,
115 role: message.role,
116 text: message.text,
117 })
118 .collect(),
119 next_message_id,
120 context: BTreeMap::default(),
121 context_by_message: HashMap::default(),
122 completion_count: 0,
123 pending_completions: Vec::new(),
124 tools,
125 tool_use,
126 }
127 }
128
129 pub fn id(&self) -> &ThreadId {
130 &self.id
131 }
132
133 pub fn is_empty(&self) -> bool {
134 self.messages.is_empty()
135 }
136
137 pub fn updated_at(&self) -> DateTime<Utc> {
138 self.updated_at
139 }
140
141 pub fn touch_updated_at(&mut self) {
142 self.updated_at = Utc::now();
143 }
144
145 pub fn summary(&self) -> Option<SharedString> {
146 self.summary.clone()
147 }
148
149 pub fn summary_or_default(&self) -> SharedString {
150 const DEFAULT: SharedString = SharedString::new_static("New Thread");
151 self.summary.clone().unwrap_or(DEFAULT)
152 }
153
154 pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
155 self.summary = Some(summary.into());
156 cx.emit(ThreadEvent::SummaryChanged);
157 }
158
159 pub fn message(&self, id: MessageId) -> Option<&Message> {
160 self.messages.iter().find(|message| message.id == id)
161 }
162
163 pub fn messages(&self) -> impl Iterator<Item = &Message> {
164 self.messages.iter()
165 }
166
167 pub fn is_streaming(&self) -> bool {
168 !self.pending_completions.is_empty()
169 }
170
171 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
172 &self.tools
173 }
174
175 pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
176 let context = self.context_by_message.get(&id)?;
177 Some(
178 context
179 .into_iter()
180 .filter_map(|context_id| self.context.get(&context_id))
181 .cloned()
182 .collect::<Vec<_>>(),
183 )
184 }
185
186 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
187 self.tool_use.pending_tool_uses()
188 }
189
190 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
191 self.tool_use.tool_uses_for_message(id)
192 }
193
194 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
195 self.tool_use.tool_results_for_message(id)
196 }
197
198 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
199 self.tool_use.message_has_tool_results(message_id)
200 }
201
202 pub fn insert_user_message(
203 &mut self,
204 text: impl Into<String>,
205 context: Vec<ContextSnapshot>,
206 cx: &mut Context<Self>,
207 ) {
208 let message_id = self.insert_message(Role::User, text, cx);
209 let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
210 self.context
211 .extend(context.into_iter().map(|context| (context.id, context)));
212 self.context_by_message.insert(message_id, context_ids);
213 }
214
215 pub fn insert_message(
216 &mut self,
217 role: Role,
218 text: impl Into<String>,
219 cx: &mut Context<Self>,
220 ) -> MessageId {
221 let id = self.next_message_id.post_inc();
222 self.messages.push(Message {
223 id,
224 role,
225 text: text.into(),
226 });
227 self.touch_updated_at();
228 cx.emit(ThreadEvent::MessageAdded(id));
229 id
230 }
231
232 /// Returns the representation of this [`Thread`] in a textual form.
233 ///
234 /// This is the representation we use when attaching a thread as context to another thread.
235 pub fn text(&self) -> String {
236 let mut text = String::new();
237
238 for message in &self.messages {
239 text.push_str(match message.role {
240 language_model::Role::User => "User:",
241 language_model::Role::Assistant => "Assistant:",
242 language_model::Role::System => "System:",
243 });
244 text.push('\n');
245
246 text.push_str(&message.text);
247 text.push('\n');
248 }
249
250 text
251 }
252
253 pub fn send_to_model(
254 &mut self,
255 model: Arc<dyn LanguageModel>,
256 request_kind: RequestKind,
257 use_tools: bool,
258 cx: &mut Context<Self>,
259 ) {
260 let mut request = self.to_completion_request(request_kind, cx);
261
262 if use_tools {
263 request.tools = self
264 .tools()
265 .tools(cx)
266 .into_iter()
267 .map(|tool| LanguageModelRequestTool {
268 name: tool.name(),
269 description: tool.description(),
270 input_schema: tool.input_schema(),
271 })
272 .collect();
273 }
274
275 self.stream_completion(request, model, cx);
276 }
277
278 pub fn to_completion_request(
279 &self,
280 request_kind: RequestKind,
281 _cx: &App,
282 ) -> LanguageModelRequest {
283 let mut request = LanguageModelRequest {
284 messages: vec![],
285 tools: Vec::new(),
286 stop: Vec::new(),
287 temperature: None,
288 };
289
290 let mut referenced_context_ids = HashSet::default();
291
292 for message in &self.messages {
293 if let Some(context_ids) = self.context_by_message.get(&message.id) {
294 referenced_context_ids.extend(context_ids);
295 }
296
297 let mut request_message = LanguageModelRequestMessage {
298 role: message.role,
299 content: Vec::new(),
300 cache: false,
301 };
302 match request_kind {
303 RequestKind::Chat => {
304 self.tool_use
305 .attach_tool_results(message.id, &mut request_message);
306 }
307 RequestKind::Summarize => {
308 // We don't care about tool use during summarization.
309 }
310 }
311
312 if !message.text.is_empty() {
313 request_message
314 .content
315 .push(MessageContent::Text(message.text.clone()));
316 }
317
318 match request_kind {
319 RequestKind::Chat => {
320 self.tool_use
321 .attach_tool_uses(message.id, &mut request_message);
322 }
323 RequestKind::Summarize => {
324 // We don't care about tool use during summarization.
325 }
326 }
327
328 request.messages.push(request_message);
329 }
330
331 if !referenced_context_ids.is_empty() {
332 let mut context_message = LanguageModelRequestMessage {
333 role: Role::User,
334 content: Vec::new(),
335 cache: false,
336 };
337
338 let referenced_context = referenced_context_ids
339 .into_iter()
340 .filter_map(|context_id| self.context.get(context_id))
341 .cloned();
342 attach_context_to_message(&mut context_message, referenced_context);
343
344 request.messages.push(context_message);
345 }
346
347 request
348 }
349
350 pub fn stream_completion(
351 &mut self,
352 request: LanguageModelRequest,
353 model: Arc<dyn LanguageModel>,
354 cx: &mut Context<Self>,
355 ) {
356 let pending_completion_id = post_inc(&mut self.completion_count);
357
358 let task = cx.spawn(|thread, mut cx| async move {
359 let stream = model.stream_completion(request, &cx);
360 let stream_completion = async {
361 let mut events = stream.await?;
362 let mut stop_reason = StopReason::EndTurn;
363
364 while let Some(event) = events.next().await {
365 let event = event?;
366
367 thread.update(&mut cx, |thread, cx| {
368 match event {
369 LanguageModelCompletionEvent::StartMessage { .. } => {
370 thread.insert_message(Role::Assistant, String::new(), cx);
371 }
372 LanguageModelCompletionEvent::Stop(reason) => {
373 stop_reason = reason;
374 }
375 LanguageModelCompletionEvent::Text(chunk) => {
376 if let Some(last_message) = thread.messages.last_mut() {
377 if last_message.role == Role::Assistant {
378 last_message.text.push_str(&chunk);
379 cx.emit(ThreadEvent::StreamedAssistantText(
380 last_message.id,
381 chunk,
382 ));
383 } else {
384 // If we won't have an Assistant message yet, assume this chunk marks the beginning
385 // of a new Assistant response.
386 //
387 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
388 // will result in duplicating the text of the chunk in the rendered Markdown.
389 thread.insert_message(Role::Assistant, chunk, cx);
390 }
391 }
392 }
393 LanguageModelCompletionEvent::ToolUse(tool_use) => {
394 if let Some(last_assistant_message) = thread
395 .messages
396 .iter()
397 .rfind(|message| message.role == Role::Assistant)
398 {
399 thread
400 .tool_use
401 .request_tool_use(last_assistant_message.id, tool_use);
402 }
403 }
404 }
405
406 thread.touch_updated_at();
407 cx.emit(ThreadEvent::StreamedCompletion);
408 cx.notify();
409 })?;
410
411 smol::future::yield_now().await;
412 }
413
414 thread.update(&mut cx, |thread, cx| {
415 thread
416 .pending_completions
417 .retain(|completion| completion.id != pending_completion_id);
418
419 if thread.summary.is_none() && thread.messages.len() >= 2 {
420 thread.summarize(cx);
421 }
422 })?;
423
424 anyhow::Ok(stop_reason)
425 };
426
427 let result = stream_completion.await;
428
429 thread
430 .update(&mut cx, |thread, cx| match result.as_ref() {
431 Ok(stop_reason) => match stop_reason {
432 StopReason::ToolUse => {
433 cx.emit(ThreadEvent::UsePendingTools);
434 }
435 StopReason::EndTurn => {}
436 StopReason::MaxTokens => {}
437 },
438 Err(error) => {
439 if error.is::<PaymentRequiredError>() {
440 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
441 } else if error.is::<MaxMonthlySpendReachedError>() {
442 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
443 } else {
444 let error_message = error
445 .chain()
446 .map(|err| err.to_string())
447 .collect::<Vec<_>>()
448 .join("\n");
449 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
450 SharedString::from(error_message.clone()),
451 )));
452 }
453
454 thread.cancel_last_completion();
455 }
456 })
457 .ok();
458 });
459
460 self.pending_completions.push(PendingCompletion {
461 id: pending_completion_id,
462 _task: task,
463 });
464 }
465
466 pub fn summarize(&mut self, cx: &mut Context<Self>) {
467 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
468 return;
469 };
470 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
471 return;
472 };
473
474 if !provider.is_authenticated(cx) {
475 return;
476 }
477
478 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
479 request.messages.push(LanguageModelRequestMessage {
480 role: Role::User,
481 content: vec![
482 "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
483 .into(),
484 ],
485 cache: false,
486 });
487
488 self.pending_summary = cx.spawn(|this, mut cx| {
489 async move {
490 let stream = model.stream_completion_text(request, &cx);
491 let mut messages = stream.await?;
492
493 let mut new_summary = String::new();
494 while let Some(message) = messages.stream.next().await {
495 let text = message?;
496 let mut lines = text.lines();
497 new_summary.extend(lines.next());
498
499 // Stop if the LLM generated multiple lines.
500 if lines.next().is_some() {
501 break;
502 }
503 }
504
505 this.update(&mut cx, |this, cx| {
506 if !new_summary.is_empty() {
507 this.summary = Some(new_summary.into());
508 }
509
510 cx.emit(ThreadEvent::SummaryChanged);
511 })?;
512
513 anyhow::Ok(())
514 }
515 .log_err()
516 });
517 }
518
519 pub fn insert_tool_output(
520 &mut self,
521 tool_use_id: LanguageModelToolUseId,
522 output: Task<Result<String>>,
523 cx: &mut Context<Self>,
524 ) {
525 let insert_output_task = cx.spawn(|thread, mut cx| {
526 let tool_use_id = tool_use_id.clone();
527 async move {
528 let output = output.await;
529 thread
530 .update(&mut cx, |thread, cx| {
531 thread
532 .tool_use
533 .insert_tool_output(tool_use_id.clone(), output);
534
535 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
536 })
537 .ok();
538 }
539 });
540
541 self.tool_use
542 .run_pending_tool(tool_use_id, insert_output_task);
543 }
544
545 /// Cancels the last pending completion, if there are any pending.
546 ///
547 /// Returns whether a completion was canceled.
548 pub fn cancel_last_completion(&mut self) -> bool {
549 if let Some(_last_completion) = self.pending_completions.pop() {
550 true
551 } else {
552 false
553 }
554 }
555}
556
557#[derive(Debug, Clone)]
558pub enum ThreadError {
559 PaymentRequired,
560 MaxMonthlySpendReached,
561 Message(SharedString),
562}
563
564#[derive(Debug, Clone)]
565pub enum ThreadEvent {
566 ShowError(ThreadError),
567 StreamedCompletion,
568 StreamedAssistantText(MessageId, String),
569 MessageAdded(MessageId),
570 SummaryChanged,
571 UsePendingTools,
572 ToolFinished {
573 #[allow(unused)]
574 tool_use_id: LanguageModelToolUseId,
575 },
576}
577
578impl EventEmitter<ThreadEvent> for Thread {}
579
580struct PendingCompletion {
581 id: usize,
582 _task: Task<()>,
583}