thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
 11    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12    StopReason,
 13};
 14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
 15use serde::{Deserialize, Serialize};
 16use util::post_inc;
 17use uuid::Uuid;
 18
 19#[derive(Debug, Clone, Copy)]
 20pub enum RequestKind {
 21    Chat,
 22}
 23
 24#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 25pub struct ThreadId(Arc<str>);
 26
 27impl ThreadId {
 28    pub fn new() -> Self {
 29        Self(Uuid::new_v4().to_string().into())
 30    }
 31}
 32
 33impl std::fmt::Display for ThreadId {
 34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 35        write!(f, "{}", self.0)
 36    }
 37}
 38
 39#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 40pub struct MessageId(usize);
 41
 42impl MessageId {
 43    fn post_inc(&mut self) -> Self {
 44        Self(post_inc(&mut self.0))
 45    }
 46}
 47
 48/// A message in a [`Thread`].
 49#[derive(Debug, Clone)]
 50pub struct Message {
 51    pub id: MessageId,
 52    pub role: Role,
 53    pub text: String,
 54}
 55
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Unique identifier of this thread.
    id: ThreadId,
    /// All messages, in insertion order.
    messages: Vec<Message>,
    /// The id that the next inserted message will receive.
    next_message_id: MessageId,
    /// Counter used to give each started completion its own id.
    completion_count: usize,
    /// Tasks for completions started via `stream_completion`.
    pending_completions: Vec<PendingCompletion>,
    /// Tools available to the model.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested by each assistant message, keyed by that message's id.
    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Tool outputs to send along with the message they are keyed by.
    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
    /// Tool uses requested by the model, with their execution status, keyed by tool-use id.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
 68
 69impl Thread {
 70    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
 71        Self {
 72            id: ThreadId::new(),
 73            messages: Vec::new(),
 74            next_message_id: MessageId(0),
 75            completion_count: 0,
 76            pending_completions: Vec::new(),
 77            tools,
 78            tool_uses_by_message: HashMap::default(),
 79            tool_results_by_message: HashMap::default(),
 80            pending_tool_uses_by_id: HashMap::default(),
 81        }
 82    }
 83
 84    pub fn id(&self) -> &ThreadId {
 85        &self.id
 86    }
 87
 88    pub fn message(&self, id: MessageId) -> Option<&Message> {
 89        self.messages.iter().find(|message| message.id == id)
 90    }
 91
 92    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 93        self.messages.iter()
 94    }
 95
 96    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 97        &self.tools
 98    }
 99
100    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
101        self.pending_tool_uses_by_id.values().collect()
102    }
103
104    pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
105        self.insert_message(Role::User, text, cx)
106    }
107
108    pub fn insert_message(
109        &mut self,
110        role: Role,
111        text: impl Into<String>,
112        cx: &mut ModelContext<Self>,
113    ) {
114        let id = self.next_message_id.post_inc();
115        self.messages.push(Message {
116            id,
117            role,
118            text: text.into(),
119        });
120        cx.emit(ThreadEvent::MessageAdded(id));
121    }
122
123    pub fn to_completion_request(
124        &self,
125        _request_kind: RequestKind,
126        _cx: &AppContext,
127    ) -> LanguageModelRequest {
128        let mut request = LanguageModelRequest {
129            messages: vec![],
130            tools: Vec::new(),
131            stop: Vec::new(),
132            temperature: None,
133        };
134
135        for message in &self.messages {
136            let mut request_message = LanguageModelRequestMessage {
137                role: message.role,
138                content: Vec::new(),
139                cache: false,
140            };
141
142            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
143                for tool_result in tool_results {
144                    request_message
145                        .content
146                        .push(MessageContent::ToolResult(tool_result.clone()));
147                }
148            }
149
150            if !message.text.is_empty() {
151                request_message
152                    .content
153                    .push(MessageContent::Text(message.text.clone()));
154            }
155
156            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
157                for tool_use in tool_uses {
158                    request_message
159                        .content
160                        .push(MessageContent::ToolUse(tool_use.clone()));
161                }
162            }
163
164            request.messages.push(request_message);
165        }
166
167        request
168    }
169
170    pub fn stream_completion(
171        &mut self,
172        request: LanguageModelRequest,
173        model: Arc<dyn LanguageModel>,
174        cx: &mut ModelContext<Self>,
175    ) {
176        let pending_completion_id = post_inc(&mut self.completion_count);
177
178        let task = cx.spawn(|thread, mut cx| async move {
179            let stream = model.stream_completion(request, &cx);
180            let stream_completion = async {
181                let mut events = stream.await?;
182                let mut stop_reason = StopReason::EndTurn;
183
184                while let Some(event) = events.next().await {
185                    let event = event?;
186
187                    thread.update(&mut cx, |thread, cx| {
188                        match event {
189                            LanguageModelCompletionEvent::StartMessage { .. } => {
190                                let id = thread.next_message_id.post_inc();
191                                thread.messages.push(Message {
192                                    id,
193                                    role: Role::Assistant,
194                                    text: String::new(),
195                                });
196                                cx.emit(ThreadEvent::MessageAdded(id));
197                            }
198                            LanguageModelCompletionEvent::Stop(reason) => {
199                                stop_reason = reason;
200                            }
201                            LanguageModelCompletionEvent::Text(chunk) => {
202                                if let Some(last_message) = thread.messages.last_mut() {
203                                    if last_message.role == Role::Assistant {
204                                        last_message.text.push_str(&chunk);
205                                        cx.emit(ThreadEvent::StreamedAssistantText(
206                                            last_message.id,
207                                            chunk,
208                                        ));
209                                    }
210                                }
211                            }
212                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
213                                if let Some(last_assistant_message) = thread
214                                    .messages
215                                    .iter()
216                                    .rfind(|message| message.role == Role::Assistant)
217                                {
218                                    thread
219                                        .tool_uses_by_message
220                                        .entry(last_assistant_message.id)
221                                        .or_default()
222                                        .push(tool_use.clone());
223
224                                    thread.pending_tool_uses_by_id.insert(
225                                        tool_use.id.clone(),
226                                        PendingToolUse {
227                                            assistant_message_id: last_assistant_message.id,
228                                            id: tool_use.id,
229                                            name: tool_use.name,
230                                            input: tool_use.input,
231                                            status: PendingToolUseStatus::Idle,
232                                        },
233                                    );
234                                }
235                            }
236                        }
237
238                        cx.emit(ThreadEvent::StreamedCompletion);
239                        cx.notify();
240                    })?;
241
242                    smol::future::yield_now().await;
243                }
244
245                thread.update(&mut cx, |thread, _cx| {
246                    thread
247                        .pending_completions
248                        .retain(|completion| completion.id != pending_completion_id);
249                })?;
250
251                anyhow::Ok(stop_reason)
252            };
253
254            let result = stream_completion.await;
255
256            thread
257                .update(&mut cx, |_thread, cx| match result.as_ref() {
258                    Ok(stop_reason) => match stop_reason {
259                        StopReason::ToolUse => {
260                            cx.emit(ThreadEvent::UsePendingTools);
261                        }
262                        StopReason::EndTurn => {}
263                        StopReason::MaxTokens => {}
264                    },
265                    Err(error) => {
266                        if error.is::<PaymentRequiredError>() {
267                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
268                        } else if error.is::<MaxMonthlySpendReachedError>() {
269                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
270                        } else {
271                            let error_message = error
272                                .chain()
273                                .map(|err| err.to_string())
274                                .collect::<Vec<_>>()
275                                .join("\n");
276                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
277                                SharedString::from(error_message.clone()),
278                            )));
279                        }
280                    }
281                })
282                .ok();
283        });
284
285        self.pending_completions.push(PendingCompletion {
286            id: pending_completion_id,
287            _task: task,
288        });
289    }
290
291    pub fn insert_tool_output(
292        &mut self,
293        assistant_message_id: MessageId,
294        tool_use_id: LanguageModelToolUseId,
295        output: Task<Result<String>>,
296        cx: &mut ModelContext<Self>,
297    ) {
298        let insert_output_task = cx.spawn(|thread, mut cx| {
299            let tool_use_id = tool_use_id.clone();
300            async move {
301                let output = output.await;
302                thread
303                    .update(&mut cx, |thread, cx| {
304                        // The tool use was requested by an Assistant message,
305                        // so we want to attach the tool results to the next
306                        // user message.
307                        let next_user_message = MessageId(assistant_message_id.0 + 1);
308
309                        let tool_results = thread
310                            .tool_results_by_message
311                            .entry(next_user_message)
312                            .or_default();
313
314                        match output {
315                            Ok(output) => {
316                                tool_results.push(LanguageModelToolResult {
317                                    tool_use_id: tool_use_id.to_string(),
318                                    content: output,
319                                    is_error: false,
320                                });
321
322                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
323                            }
324                            Err(err) => {
325                                tool_results.push(LanguageModelToolResult {
326                                    tool_use_id: tool_use_id.to_string(),
327                                    content: err.to_string(),
328                                    is_error: true,
329                                });
330
331                                if let Some(tool_use) =
332                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
333                                {
334                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
335                                }
336                            }
337                        }
338                    })
339                    .ok();
340            }
341        });
342
343        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
344            tool_use.status = PendingToolUseStatus::Running {
345                _task: insert_output_task.shared(),
346            };
347        }
348    }
349}
350
351#[derive(Debug, Clone)]
352pub enum ThreadError {
353    PaymentRequired,
354    MaxMonthlySpendReached,
355    Message(SharedString),
356}
357
358#[derive(Debug, Clone)]
359pub enum ThreadEvent {
360    ShowError(ThreadError),
361    StreamedCompletion,
362    StreamedAssistantText(MessageId, String),
363    MessageAdded(MessageId),
364    UsePendingTools,
365    ToolFinished {
366        #[allow(unused)]
367        tool_use_id: LanguageModelToolUseId,
368    },
369}
370
371impl EventEmitter<ThreadEvent> for Thread {}
372
/// A completion started by `Thread::stream_completion`.
struct PendingCompletion {
    /// Id drawn from the thread's completion counter; used to find and remove this entry.
    id: usize,
    /// Held only to own the streaming task; the field is never read.
    _task: Task<()>,
}
377
378#[derive(Debug, Clone)]
379pub struct PendingToolUse {
380    pub id: LanguageModelToolUseId,
381    /// The ID of the Assistant message in which the tool use was requested.
382    pub assistant_message_id: MessageId,
383    pub name: String,
384    pub input: serde_json::Value,
385    pub status: PendingToolUseStatus,
386}
387
388#[derive(Debug, Clone)]
389pub enum PendingToolUseStatus {
390    Idle,
391    Running { _task: Shared<Task<()>> },
392    Error(#[allow(unused)] String),
393}
394
395impl PendingToolUseStatus {
396    pub fn is_idle(&self) -> bool {
397        matches!(self, PendingToolUseStatus::Idle)
398    }
399}