// thread.rs

use std::sync::Arc;

use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
    StopReason,
};
use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
use serde::{Deserialize, Serialize};
use util::post_inc;
use uuid::Uuid;

 19#[derive(Debug, Clone, Copy)]
 20pub enum RequestKind {
 21    Chat,
 22}
 23
 24#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 25pub struct ThreadId(Arc<str>);
 26
 27impl ThreadId {
 28    pub fn new() -> Self {
 29        Self(Uuid::new_v4().to_string().into())
 30    }
 31}
 32
 33impl std::fmt::Display for ThreadId {
 34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 35        write!(f, "{}", self.0)
 36    }
 37}
 38
 39#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 40pub struct MessageId(usize);
 41
 42impl MessageId {
 43    fn post_inc(&mut self) -> Self {
 44        Self(post_inc(&mut self.0))
 45    }
 46}
 47
 48/// A message in a [`Thread`].
 49#[derive(Debug, Clone)]
 50pub struct Message {
 51    pub id: MessageId,
 52    pub role: Role,
 53    pub text: String,
 54}
 55
 56/// A thread of conversation with the LLM.
 57pub struct Thread {
 58    id: ThreadId,
 59    messages: Vec<Message>,
 60    next_message_id: MessageId,
 61    completion_count: usize,
 62    pending_completions: Vec<PendingCompletion>,
 63    tools: Arc<ToolWorkingSet>,
 64    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 65    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
 66    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 67}
 68
 69impl Thread {
 70    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
 71        Self {
 72            id: ThreadId::new(),
 73            messages: Vec::new(),
 74            next_message_id: MessageId(0),
 75            completion_count: 0,
 76            pending_completions: Vec::new(),
 77            tools,
 78            tool_uses_by_message: HashMap::default(),
 79            tool_results_by_message: HashMap::default(),
 80            pending_tool_uses_by_id: HashMap::default(),
 81        }
 82    }
 83
 84    pub fn id(&self) -> &ThreadId {
 85        &self.id
 86    }
 87
 88    pub fn is_empty(&self) -> bool {
 89        self.messages.is_empty()
 90    }
 91
 92    pub fn message(&self, id: MessageId) -> Option<&Message> {
 93        self.messages.iter().find(|message| message.id == id)
 94    }
 95
 96    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 97        self.messages.iter()
 98    }
 99
100    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
101        &self.tools
102    }
103
104    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
105        self.pending_tool_uses_by_id.values().collect()
106    }
107
108    pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
109        self.insert_message(Role::User, text, cx)
110    }
111
112    pub fn insert_message(
113        &mut self,
114        role: Role,
115        text: impl Into<String>,
116        cx: &mut ModelContext<Self>,
117    ) {
118        let id = self.next_message_id.post_inc();
119        self.messages.push(Message {
120            id,
121            role,
122            text: text.into(),
123        });
124        cx.emit(ThreadEvent::MessageAdded(id));
125    }
126
127    pub fn to_completion_request(
128        &self,
129        _request_kind: RequestKind,
130        _cx: &AppContext,
131    ) -> LanguageModelRequest {
132        let mut request = LanguageModelRequest {
133            messages: vec![],
134            tools: Vec::new(),
135            stop: Vec::new(),
136            temperature: None,
137        };
138
139        for message in &self.messages {
140            let mut request_message = LanguageModelRequestMessage {
141                role: message.role,
142                content: Vec::new(),
143                cache: false,
144            };
145
146            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
147                for tool_result in tool_results {
148                    request_message
149                        .content
150                        .push(MessageContent::ToolResult(tool_result.clone()));
151                }
152            }
153
154            if !message.text.is_empty() {
155                request_message
156                    .content
157                    .push(MessageContent::Text(message.text.clone()));
158            }
159
160            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
161                for tool_use in tool_uses {
162                    request_message
163                        .content
164                        .push(MessageContent::ToolUse(tool_use.clone()));
165                }
166            }
167
168            request.messages.push(request_message);
169        }
170
171        request
172    }
173
174    pub fn stream_completion(
175        &mut self,
176        request: LanguageModelRequest,
177        model: Arc<dyn LanguageModel>,
178        cx: &mut ModelContext<Self>,
179    ) {
180        let pending_completion_id = post_inc(&mut self.completion_count);
181
182        let task = cx.spawn(|thread, mut cx| async move {
183            let stream = model.stream_completion(request, &cx);
184            let stream_completion = async {
185                let mut events = stream.await?;
186                let mut stop_reason = StopReason::EndTurn;
187
188                while let Some(event) = events.next().await {
189                    let event = event?;
190
191                    thread.update(&mut cx, |thread, cx| {
192                        match event {
193                            LanguageModelCompletionEvent::StartMessage { .. } => {
194                                let id = thread.next_message_id.post_inc();
195                                thread.messages.push(Message {
196                                    id,
197                                    role: Role::Assistant,
198                                    text: String::new(),
199                                });
200                                cx.emit(ThreadEvent::MessageAdded(id));
201                            }
202                            LanguageModelCompletionEvent::Stop(reason) => {
203                                stop_reason = reason;
204                            }
205                            LanguageModelCompletionEvent::Text(chunk) => {
206                                if let Some(last_message) = thread.messages.last_mut() {
207                                    if last_message.role == Role::Assistant {
208                                        last_message.text.push_str(&chunk);
209                                        cx.emit(ThreadEvent::StreamedAssistantText(
210                                            last_message.id,
211                                            chunk,
212                                        ));
213                                    }
214                                }
215                            }
216                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
217                                if let Some(last_assistant_message) = thread
218                                    .messages
219                                    .iter()
220                                    .rfind(|message| message.role == Role::Assistant)
221                                {
222                                    thread
223                                        .tool_uses_by_message
224                                        .entry(last_assistant_message.id)
225                                        .or_default()
226                                        .push(tool_use.clone());
227
228                                    thread.pending_tool_uses_by_id.insert(
229                                        tool_use.id.clone(),
230                                        PendingToolUse {
231                                            assistant_message_id: last_assistant_message.id,
232                                            id: tool_use.id,
233                                            name: tool_use.name,
234                                            input: tool_use.input,
235                                            status: PendingToolUseStatus::Idle,
236                                        },
237                                    );
238                                }
239                            }
240                        }
241
242                        cx.emit(ThreadEvent::StreamedCompletion);
243                        cx.notify();
244                    })?;
245
246                    smol::future::yield_now().await;
247                }
248
249                thread.update(&mut cx, |thread, _cx| {
250                    thread
251                        .pending_completions
252                        .retain(|completion| completion.id != pending_completion_id);
253                })?;
254
255                anyhow::Ok(stop_reason)
256            };
257
258            let result = stream_completion.await;
259
260            thread
261                .update(&mut cx, |_thread, cx| match result.as_ref() {
262                    Ok(stop_reason) => match stop_reason {
263                        StopReason::ToolUse => {
264                            cx.emit(ThreadEvent::UsePendingTools);
265                        }
266                        StopReason::EndTurn => {}
267                        StopReason::MaxTokens => {}
268                    },
269                    Err(error) => {
270                        if error.is::<PaymentRequiredError>() {
271                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
272                        } else if error.is::<MaxMonthlySpendReachedError>() {
273                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
274                        } else {
275                            let error_message = error
276                                .chain()
277                                .map(|err| err.to_string())
278                                .collect::<Vec<_>>()
279                                .join("\n");
280                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
281                                SharedString::from(error_message.clone()),
282                            )));
283                        }
284                    }
285                })
286                .ok();
287        });
288
289        self.pending_completions.push(PendingCompletion {
290            id: pending_completion_id,
291            _task: task,
292        });
293    }
294
295    pub fn insert_tool_output(
296        &mut self,
297        assistant_message_id: MessageId,
298        tool_use_id: LanguageModelToolUseId,
299        output: Task<Result<String>>,
300        cx: &mut ModelContext<Self>,
301    ) {
302        let insert_output_task = cx.spawn(|thread, mut cx| {
303            let tool_use_id = tool_use_id.clone();
304            async move {
305                let output = output.await;
306                thread
307                    .update(&mut cx, |thread, cx| {
308                        // The tool use was requested by an Assistant message,
309                        // so we want to attach the tool results to the next
310                        // user message.
311                        let next_user_message = MessageId(assistant_message_id.0 + 1);
312
313                        let tool_results = thread
314                            .tool_results_by_message
315                            .entry(next_user_message)
316                            .or_default();
317
318                        match output {
319                            Ok(output) => {
320                                tool_results.push(LanguageModelToolResult {
321                                    tool_use_id: tool_use_id.to_string(),
322                                    content: output,
323                                    is_error: false,
324                                });
325
326                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
327                            }
328                            Err(err) => {
329                                tool_results.push(LanguageModelToolResult {
330                                    tool_use_id: tool_use_id.to_string(),
331                                    content: err.to_string(),
332                                    is_error: true,
333                                });
334
335                                if let Some(tool_use) =
336                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
337                                {
338                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
339                                }
340                            }
341                        }
342                    })
343                    .ok();
344            }
345        });
346
347        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
348            tool_use.status = PendingToolUseStatus::Running {
349                _task: insert_output_task.shared(),
350            };
351        }
352    }
353}
354
355#[derive(Debug, Clone)]
356pub enum ThreadError {
357    PaymentRequired,
358    MaxMonthlySpendReached,
359    Message(SharedString),
360}
361
362#[derive(Debug, Clone)]
363pub enum ThreadEvent {
364    ShowError(ThreadError),
365    StreamedCompletion,
366    StreamedAssistantText(MessageId, String),
367    MessageAdded(MessageId),
368    UsePendingTools,
369    ToolFinished {
370        #[allow(unused)]
371        tool_use_id: LanguageModelToolUseId,
372    },
373}
374
375impl EventEmitter<ThreadEvent> for Thread {}
376
377struct PendingCompletion {
378    id: usize,
379    _task: Task<()>,
380}
381
382#[derive(Debug, Clone)]
383pub struct PendingToolUse {
384    pub id: LanguageModelToolUseId,
385    /// The ID of the Assistant message in which the tool use was requested.
386    pub assistant_message_id: MessageId,
387    pub name: String,
388    pub input: serde_json::Value,
389    pub status: PendingToolUseStatus,
390}
391
392#[derive(Debug, Clone)]
393pub enum PendingToolUseStatus {
394    Idle,
395    Running { _task: Shared<Task<()>> },
396    Error(#[allow(unused)] String),
397}
398
399impl PendingToolUseStatus {
400    pub fn is_idle(&self) -> bool {
401        matches!(self, PendingToolUseStatus::Idle)
402    }
403}