thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use chrono::{DateTime, Utc};
  6use collections::{BTreeMap, HashMap, HashSet};
  7use futures::StreamExt as _;
  8use gpui::{App, Context, EventEmitter, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
 11    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
 12    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
 13    Role, StopReason,
 14};
 15use serde::{Deserialize, Serialize};
 16use util::{post_inc, TryFutureExt as _};
 17use uuid::Uuid;
 18
 19use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 20use crate::thread_store::SavedThread;
 21use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 22
/// The purpose a completion request is being built for.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat request; tool uses and tool results are attached to the messages.
    Chat,
    /// Used when summarizing a thread; tool use is left out of the request.
    Summarize,
}
 29
 30#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 31pub struct ThreadId(Arc<str>);
 32
 33impl ThreadId {
 34    pub fn new() -> Self {
 35        Self(Uuid::new_v4().to_string().into())
 36    }
 37}
 38
 39impl std::fmt::Display for ThreadId {
 40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 41        write!(f, "{}", self.0)
 42    }
 43}
 44
 45#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 46pub struct MessageId(pub(crate) usize);
 47
 48impl MessageId {
 49    fn post_inc(&mut self) -> Self {
 50        Self(post_inc(&mut self.0))
 51    }
 52}
 53
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// The id assigned to this message when it was inserted into the thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// The textual content of the message.
    pub text: String,
}
 61
/// A thread of conversation with the LLM.
pub struct Thread {
    /// The identifier of this thread.
    id: ThreadId,
    /// When the thread was last modified; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// The title of the thread, or `None` if none has been set or generated yet.
    summary: Option<SharedString>,
    /// The task assigned by `summarize` while a summary is being generated.
    pending_summary: Task<Option<()>>,
    /// The messages of the thread, in insertion order.
    messages: Vec<Message>,
    /// The id that will be given to the next inserted message.
    next_message_id: MessageId,
    /// Every context snapshot attached to this thread, keyed by its id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// The ids of the context attached to each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Counter used to allocate ids for pending completions.
    completion_count: usize,
    /// Completions that have been started and not yet finished or canceled.
    pending_completions: Vec<PendingCompletion>,
    /// The tools that can be offered to the model.
    tools: Arc<ToolWorkingSet>,
    /// Bookkeeping for tool uses requested by the model and their results.
    tool_use: ToolUseState,
}
 77
 78impl Thread {
 79    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
 80        Self {
 81            id: ThreadId::new(),
 82            updated_at: Utc::now(),
 83            summary: None,
 84            pending_summary: Task::ready(None),
 85            messages: Vec::new(),
 86            next_message_id: MessageId(0),
 87            context: BTreeMap::default(),
 88            context_by_message: HashMap::default(),
 89            completion_count: 0,
 90            pending_completions: Vec::new(),
 91            tools,
 92            tool_use: ToolUseState::new(),
 93        }
 94    }
 95
 96    pub fn from_saved(
 97        id: ThreadId,
 98        saved: SavedThread,
 99        tools: Arc<ToolWorkingSet>,
100        _cx: &mut Context<Self>,
101    ) -> Self {
102        let next_message_id = MessageId(
103            saved
104                .messages
105                .last()
106                .map(|message| message.id.0 + 1)
107                .unwrap_or(0),
108        );
109        let tool_use = ToolUseState::from_saved_messages(&saved.messages);
110
111        Self {
112            id,
113            updated_at: saved.updated_at,
114            summary: Some(saved.summary),
115            pending_summary: Task::ready(None),
116            messages: saved
117                .messages
118                .into_iter()
119                .map(|message| Message {
120                    id: message.id,
121                    role: message.role,
122                    text: message.text,
123                })
124                .collect(),
125            next_message_id,
126            context: BTreeMap::default(),
127            context_by_message: HashMap::default(),
128            completion_count: 0,
129            pending_completions: Vec::new(),
130            tools,
131            tool_use,
132        }
133    }
134
    /// Returns the identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
138
    /// Returns whether the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
142
    /// Returns the time at which the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
146
    /// Sets the last-modified time of the thread to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
150
    /// Returns a copy of the thread's summary, or `None` if it does not have one yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
154
155    pub fn summary_or_default(&self) -> SharedString {
156        const DEFAULT: SharedString = SharedString::new_static("New Thread");
157        self.summary.clone().unwrap_or(DEFAULT)
158    }
159
160    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
161        self.summary = Some(summary.into());
162        cx.emit(ThreadEvent::SummaryChanged);
163    }
164
165    pub fn message(&self, id: MessageId) -> Option<&Message> {
166        self.messages.iter().find(|message| message.id == id)
167    }
168
    /// Returns an iterator over the messages of the thread, in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
172
    /// Returns whether at least one completion is still pending.
    pub fn is_streaming(&self) -> bool {
        !self.pending_completions.is_empty()
    }
176
    /// Returns the tool working set this thread was created with.
    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }
180
181    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
182        let context = self.context_by_message.get(&id)?;
183        Some(
184            context
185                .into_iter()
186                .filter_map(|context_id| self.context.get(&context_id))
187                .cloned()
188                .collect::<Vec<_>>(),
189        )
190    }
191
    /// Returns the pending tool uses tracked by this thread's tool-use state.
    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
        self.tool_use.pending_tool_uses()
    }
195
    /// Returns the tool uses that the tool-use state has recorded for the given message.
    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id)
    }
199
    /// Returns the tool results that the tool-use state has recorded for the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
203
    /// Returns whether the tool-use state has any tool results for the given message.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
207
208    pub fn insert_user_message(
209        &mut self,
210        text: impl Into<String>,
211        context: Vec<ContextSnapshot>,
212        cx: &mut Context<Self>,
213    ) {
214        let message_id = self.insert_message(Role::User, text, cx);
215        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
216        self.context
217            .extend(context.into_iter().map(|context| (context.id, context)));
218        self.context_by_message.insert(message_id, context_ids);
219    }
220
221    pub fn insert_message(
222        &mut self,
223        role: Role,
224        text: impl Into<String>,
225        cx: &mut Context<Self>,
226    ) -> MessageId {
227        let id = self.next_message_id.post_inc();
228        self.messages.push(Message {
229            id,
230            role,
231            text: text.into(),
232        });
233        self.touch_updated_at();
234        cx.emit(ThreadEvent::MessageAdded(id));
235        id
236    }
237
238    pub fn edit_message(
239        &mut self,
240        id: MessageId,
241        new_role: Role,
242        new_text: String,
243        cx: &mut Context<Self>,
244    ) -> bool {
245        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
246            return false;
247        };
248        message.role = new_role;
249        message.text = new_text;
250        self.touch_updated_at();
251        cx.emit(ThreadEvent::MessageEdited(id));
252        true
253    }
254
255    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
256        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
257            return false;
258        };
259        self.messages.remove(index);
260        self.context_by_message.remove(&id);
261        self.touch_updated_at();
262        cx.emit(ThreadEvent::MessageDeleted(id));
263        true
264    }
265
266    /// Returns the representation of this [`Thread`] in a textual form.
267    ///
268    /// This is the representation we use when attaching a thread as context to another thread.
269    pub fn text(&self) -> String {
270        let mut text = String::new();
271
272        for message in &self.messages {
273            text.push_str(match message.role {
274                language_model::Role::User => "User:",
275                language_model::Role::Assistant => "Assistant:",
276                language_model::Role::System => "System:",
277            });
278            text.push('\n');
279
280            text.push_str(&message.text);
281            text.push('\n');
282        }
283
284        text
285    }
286
287    pub fn send_to_model(
288        &mut self,
289        model: Arc<dyn LanguageModel>,
290        request_kind: RequestKind,
291        use_tools: bool,
292        cx: &mut Context<Self>,
293    ) {
294        let mut request = self.to_completion_request(request_kind, cx);
295
296        if use_tools {
297            request.tools = self
298                .tools()
299                .tools(cx)
300                .into_iter()
301                .map(|tool| LanguageModelRequestTool {
302                    name: tool.name(),
303                    description: tool.description(),
304                    input_schema: tool.input_schema(),
305                })
306                .collect();
307        }
308
309        self.stream_completion(request, model, cx);
310    }
311
312    pub fn to_completion_request(
313        &self,
314        request_kind: RequestKind,
315        _cx: &App,
316    ) -> LanguageModelRequest {
317        let mut request = LanguageModelRequest {
318            messages: vec![],
319            tools: Vec::new(),
320            stop: Vec::new(),
321            temperature: None,
322        };
323
324        let mut referenced_context_ids = HashSet::default();
325
326        for message in &self.messages {
327            if let Some(context_ids) = self.context_by_message.get(&message.id) {
328                referenced_context_ids.extend(context_ids);
329            }
330
331            let mut request_message = LanguageModelRequestMessage {
332                role: message.role,
333                content: Vec::new(),
334                cache: false,
335            };
336            match request_kind {
337                RequestKind::Chat => {
338                    self.tool_use
339                        .attach_tool_results(message.id, &mut request_message);
340                }
341                RequestKind::Summarize => {
342                    // We don't care about tool use during summarization.
343                }
344            }
345
346            if !message.text.is_empty() {
347                request_message
348                    .content
349                    .push(MessageContent::Text(message.text.clone()));
350            }
351
352            match request_kind {
353                RequestKind::Chat => {
354                    self.tool_use
355                        .attach_tool_uses(message.id, &mut request_message);
356                }
357                RequestKind::Summarize => {
358                    // We don't care about tool use during summarization.
359                }
360            }
361
362            request.messages.push(request_message);
363        }
364
365        if !referenced_context_ids.is_empty() {
366            let mut context_message = LanguageModelRequestMessage {
367                role: Role::User,
368                content: Vec::new(),
369                cache: false,
370            };
371
372            let referenced_context = referenced_context_ids
373                .into_iter()
374                .filter_map(|context_id| self.context.get(context_id))
375                .cloned();
376            attach_context_to_message(&mut context_message, referenced_context);
377
378            request.messages.push(context_message);
379        }
380
381        request
382    }
383
384    pub fn stream_completion(
385        &mut self,
386        request: LanguageModelRequest,
387        model: Arc<dyn LanguageModel>,
388        cx: &mut Context<Self>,
389    ) {
390        let pending_completion_id = post_inc(&mut self.completion_count);
391
392        let task = cx.spawn(|thread, mut cx| async move {
393            let stream = model.stream_completion(request, &cx);
394            let stream_completion = async {
395                let mut events = stream.await?;
396                let mut stop_reason = StopReason::EndTurn;
397
398                while let Some(event) = events.next().await {
399                    let event = event?;
400
401                    thread.update(&mut cx, |thread, cx| {
402                        match event {
403                            LanguageModelCompletionEvent::StartMessage { .. } => {
404                                thread.insert_message(Role::Assistant, String::new(), cx);
405                            }
406                            LanguageModelCompletionEvent::Stop(reason) => {
407                                stop_reason = reason;
408                            }
409                            LanguageModelCompletionEvent::Text(chunk) => {
410                                if let Some(last_message) = thread.messages.last_mut() {
411                                    if last_message.role == Role::Assistant {
412                                        last_message.text.push_str(&chunk);
413                                        cx.emit(ThreadEvent::StreamedAssistantText(
414                                            last_message.id,
415                                            chunk,
416                                        ));
417                                    } else {
418                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
419                                        // of a new Assistant response.
420                                        //
421                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
422                                        // will result in duplicating the text of the chunk in the rendered Markdown.
423                                        thread.insert_message(Role::Assistant, chunk, cx);
424                                    }
425                                }
426                            }
427                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
428                                if let Some(last_assistant_message) = thread
429                                    .messages
430                                    .iter()
431                                    .rfind(|message| message.role == Role::Assistant)
432                                {
433                                    thread
434                                        .tool_use
435                                        .request_tool_use(last_assistant_message.id, tool_use);
436                                }
437                            }
438                        }
439
440                        thread.touch_updated_at();
441                        cx.emit(ThreadEvent::StreamedCompletion);
442                        cx.notify();
443                    })?;
444
445                    smol::future::yield_now().await;
446                }
447
448                thread.update(&mut cx, |thread, cx| {
449                    thread
450                        .pending_completions
451                        .retain(|completion| completion.id != pending_completion_id);
452
453                    if thread.summary.is_none() && thread.messages.len() >= 2 {
454                        thread.summarize(cx);
455                    }
456                })?;
457
458                anyhow::Ok(stop_reason)
459            };
460
461            let result = stream_completion.await;
462
463            thread
464                .update(&mut cx, |thread, cx| match result.as_ref() {
465                    Ok(stop_reason) => match stop_reason {
466                        StopReason::ToolUse => {
467                            cx.emit(ThreadEvent::UsePendingTools);
468                        }
469                        StopReason::EndTurn => {}
470                        StopReason::MaxTokens => {}
471                    },
472                    Err(error) => {
473                        if error.is::<PaymentRequiredError>() {
474                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
475                        } else if error.is::<MaxMonthlySpendReachedError>() {
476                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
477                        } else {
478                            let error_message = error
479                                .chain()
480                                .map(|err| err.to_string())
481                                .collect::<Vec<_>>()
482                                .join("\n");
483                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
484                                SharedString::from(error_message.clone()),
485                            )));
486                        }
487
488                        thread.cancel_last_completion();
489                    }
490                })
491                .ok();
492        });
493
494        self.pending_completions.push(PendingCompletion {
495            id: pending_completion_id,
496            _task: task,
497        });
498    }
499
500    pub fn summarize(&mut self, cx: &mut Context<Self>) {
501        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
502            return;
503        };
504        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
505            return;
506        };
507
508        if !provider.is_authenticated(cx) {
509            return;
510        }
511
512        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
513        request.messages.push(LanguageModelRequestMessage {
514            role: Role::User,
515            content: vec![
516                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
517                    .into(),
518            ],
519            cache: false,
520        });
521
522        self.pending_summary = cx.spawn(|this, mut cx| {
523            async move {
524                let stream = model.stream_completion_text(request, &cx);
525                let mut messages = stream.await?;
526
527                let mut new_summary = String::new();
528                while let Some(message) = messages.stream.next().await {
529                    let text = message?;
530                    let mut lines = text.lines();
531                    new_summary.extend(lines.next());
532
533                    // Stop if the LLM generated multiple lines.
534                    if lines.next().is_some() {
535                        break;
536                    }
537                }
538
539                this.update(&mut cx, |this, cx| {
540                    if !new_summary.is_empty() {
541                        this.summary = Some(new_summary.into());
542                    }
543
544                    cx.emit(ThreadEvent::SummaryChanged);
545                })?;
546
547                anyhow::Ok(())
548            }
549            .log_err()
550        });
551    }
552
553    pub fn insert_tool_output(
554        &mut self,
555        tool_use_id: LanguageModelToolUseId,
556        output: Task<Result<String>>,
557        cx: &mut Context<Self>,
558    ) {
559        let insert_output_task = cx.spawn(|thread, mut cx| {
560            let tool_use_id = tool_use_id.clone();
561            async move {
562                let output = output.await;
563                thread
564                    .update(&mut cx, |thread, cx| {
565                        thread
566                            .tool_use
567                            .insert_tool_output(tool_use_id.clone(), output);
568
569                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
570                    })
571                    .ok();
572            }
573        });
574
575        self.tool_use
576            .run_pending_tool(tool_use_id, insert_output_task);
577    }
578
579    /// Cancels the last pending completion, if there are any pending.
580    ///
581    /// Returns whether a completion was canceled.
582    pub fn cancel_last_completion(&mut self) -> bool {
583        if let Some(_last_completion) = self.pending_completions.pop() {
584            true
585        } else {
586            false
587        }
588    }
589}
590
/// An error surfaced to listeners through [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// The completion failed with a `PaymentRequiredError`.
    PaymentRequired,
    /// The completion failed with a `MaxMonthlySpendReachedError`.
    MaxMonthlySpendReached,
    /// Any other failure, carrying the messages of the error chain joined by newlines.
    Message(SharedString),
}
597
/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A completion failed with the given error.
    ShowError(ThreadError),
    /// An event from a streaming completion has been applied to the thread.
    StreamedCompletion,
    /// A chunk of text was appended to the existing assistant message with the given id.
    StreamedAssistantText(MessageId, String),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// The role and text of a message were replaced.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// The summary was set explicitly, or a summary generation finished.
    SummaryChanged,
    /// A completion stopped because the model wants to use tools.
    UsePendingTools,
    /// The output of a tool has been recorded.
    ToolFinished {
        /// The id of the tool use whose output was recorded.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
    },
}
613
// Allows a `Thread` to emit `ThreadEvent`s through its context.
impl EventEmitter<ThreadEvent> for Thread {}
615
/// A completion that has been started and is still tracked by its [`Thread`].
struct PendingCompletion {
    /// The id allocated from the thread's completion counter; used to find this entry again.
    id: usize,
    /// The task streaming the completion; held only so it is owned by this entry.
    _task: Task<()>,
}