thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
 11    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12    StopReason,
 13};
 14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
 15use serde::{Deserialize, Serialize};
 16use util::post_inc;
 17
 18#[derive(Debug, Clone, Copy)]
 19pub enum RequestKind {
 20    Chat,
 21}
 22
 23#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 24pub struct MessageId(usize);
 25
 26impl MessageId {
 27    fn post_inc(&mut self) -> Self {
 28        Self(post_inc(&mut self.0))
 29    }
 30}
 31
 32/// A message in a [`Thread`].
 33#[derive(Debug, Clone)]
 34pub struct Message {
 35    pub id: MessageId,
 36    pub role: Role,
 37    pub text: String,
 38}
 39
 40/// A thread of conversation with the LLM.
 41pub struct Thread {
 42    messages: Vec<Message>,
 43    next_message_id: MessageId,
 44    completion_count: usize,
 45    pending_completions: Vec<PendingCompletion>,
 46    tools: Arc<ToolWorkingSet>,
 47    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 48    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
 49    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 50}
 51
 52impl Thread {
 53    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
 54        Self {
 55            messages: Vec::new(),
 56            next_message_id: MessageId(0),
 57            completion_count: 0,
 58            pending_completions: Vec::new(),
 59            tools,
 60            tool_uses_by_message: HashMap::default(),
 61            tool_results_by_message: HashMap::default(),
 62            pending_tool_uses_by_id: HashMap::default(),
 63        }
 64    }
 65
 66    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 67        self.messages.iter()
 68    }
 69
 70    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 71        &self.tools
 72    }
 73
 74    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
 75        self.pending_tool_uses_by_id.values().collect()
 76    }
 77
 78    pub fn insert_user_message(&mut self, text: impl Into<String>) {
 79        self.messages.push(Message {
 80            id: self.next_message_id.post_inc(),
 81            role: Role::User,
 82            text: text.into(),
 83        });
 84    }
 85
 86    pub fn to_completion_request(
 87        &self,
 88        _request_kind: RequestKind,
 89        _cx: &AppContext,
 90    ) -> LanguageModelRequest {
 91        let mut request = LanguageModelRequest {
 92            messages: vec![],
 93            tools: Vec::new(),
 94            stop: Vec::new(),
 95            temperature: None,
 96        };
 97
 98        for message in &self.messages {
 99            let mut request_message = LanguageModelRequestMessage {
100                role: message.role,
101                content: Vec::new(),
102                cache: false,
103            };
104
105            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
106                for tool_result in tool_results {
107                    request_message
108                        .content
109                        .push(MessageContent::ToolResult(tool_result.clone()));
110                }
111            }
112
113            if !message.text.is_empty() {
114                request_message
115                    .content
116                    .push(MessageContent::Text(message.text.clone()));
117            }
118
119            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
120                for tool_use in tool_uses {
121                    request_message
122                        .content
123                        .push(MessageContent::ToolUse(tool_use.clone()));
124                }
125            }
126
127            request.messages.push(request_message);
128        }
129
130        request
131    }
132
133    pub fn stream_completion(
134        &mut self,
135        request: LanguageModelRequest,
136        model: Arc<dyn LanguageModel>,
137        cx: &mut ModelContext<Self>,
138    ) {
139        let pending_completion_id = post_inc(&mut self.completion_count);
140
141        let task = cx.spawn(|thread, mut cx| async move {
142            let stream = model.stream_completion(request, &cx);
143            let stream_completion = async {
144                let mut events = stream.await?;
145                let mut stop_reason = StopReason::EndTurn;
146
147                while let Some(event) = events.next().await {
148                    let event = event?;
149
150                    thread.update(&mut cx, |thread, cx| {
151                        match event {
152                            LanguageModelCompletionEvent::StartMessage { .. } => {
153                                thread.messages.push(Message {
154                                    id: thread.next_message_id.post_inc(),
155                                    role: Role::Assistant,
156                                    text: String::new(),
157                                });
158                            }
159                            LanguageModelCompletionEvent::Stop(reason) => {
160                                stop_reason = reason;
161                            }
162                            LanguageModelCompletionEvent::Text(chunk) => {
163                                if let Some(last_message) = thread.messages.last_mut() {
164                                    if last_message.role == Role::Assistant {
165                                        last_message.text.push_str(&chunk);
166                                    }
167                                }
168                            }
169                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
170                                if let Some(last_assistant_message) = thread
171                                    .messages
172                                    .iter()
173                                    .rfind(|message| message.role == Role::Assistant)
174                                {
175                                    thread
176                                        .tool_uses_by_message
177                                        .entry(last_assistant_message.id)
178                                        .or_default()
179                                        .push(tool_use.clone());
180
181                                    thread.pending_tool_uses_by_id.insert(
182                                        tool_use.id.clone(),
183                                        PendingToolUse {
184                                            assistant_message_id: last_assistant_message.id,
185                                            id: tool_use.id,
186                                            name: tool_use.name,
187                                            input: tool_use.input,
188                                            status: PendingToolUseStatus::Idle,
189                                        },
190                                    );
191                                }
192                            }
193                        }
194
195                        cx.emit(ThreadEvent::StreamedCompletion);
196                        cx.notify();
197                    })?;
198
199                    smol::future::yield_now().await;
200                }
201
202                thread.update(&mut cx, |thread, _cx| {
203                    thread
204                        .pending_completions
205                        .retain(|completion| completion.id != pending_completion_id);
206                })?;
207
208                anyhow::Ok(stop_reason)
209            };
210
211            let result = stream_completion.await;
212
213            thread
214                .update(&mut cx, |_thread, cx| match result.as_ref() {
215                    Ok(stop_reason) => match stop_reason {
216                        StopReason::ToolUse => {
217                            cx.emit(ThreadEvent::UsePendingTools);
218                        }
219                        StopReason::EndTurn => {}
220                        StopReason::MaxTokens => {}
221                    },
222                    Err(error) => {
223                        if error.is::<PaymentRequiredError>() {
224                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
225                        } else if error.is::<MaxMonthlySpendReachedError>() {
226                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
227                        } else {
228                            let error_message = error
229                                .chain()
230                                .map(|err| err.to_string())
231                                .collect::<Vec<_>>()
232                                .join("\n");
233                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
234                                SharedString::from(error_message.clone()),
235                            )));
236                        }
237                    }
238                })
239                .ok();
240        });
241
242        self.pending_completions.push(PendingCompletion {
243            id: pending_completion_id,
244            _task: task,
245        });
246    }
247
248    pub fn insert_tool_output(
249        &mut self,
250        assistant_message_id: MessageId,
251        tool_use_id: LanguageModelToolUseId,
252        output: Task<Result<String>>,
253        cx: &mut ModelContext<Self>,
254    ) {
255        let insert_output_task = cx.spawn(|thread, mut cx| {
256            let tool_use_id = tool_use_id.clone();
257            async move {
258                let output = output.await;
259                thread
260                    .update(&mut cx, |thread, cx| {
261                        // The tool use was requested by an Assistant message,
262                        // so we want to attach the tool results to the next
263                        // user message.
264                        let next_user_message = MessageId(assistant_message_id.0 + 1);
265
266                        let tool_results = thread
267                            .tool_results_by_message
268                            .entry(next_user_message)
269                            .or_default();
270
271                        match output {
272                            Ok(output) => {
273                                tool_results.push(LanguageModelToolResult {
274                                    tool_use_id: tool_use_id.to_string(),
275                                    content: output,
276                                    is_error: false,
277                                });
278
279                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
280                            }
281                            Err(err) => {
282                                tool_results.push(LanguageModelToolResult {
283                                    tool_use_id: tool_use_id.to_string(),
284                                    content: err.to_string(),
285                                    is_error: true,
286                                });
287
288                                if let Some(tool_use) =
289                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
290                                {
291                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
292                                }
293                            }
294                        }
295                    })
296                    .ok();
297            }
298        });
299
300        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
301            tool_use.status = PendingToolUseStatus::Running {
302                _task: insert_output_task.shared(),
303            };
304        }
305    }
306}
307
308#[derive(Debug, Clone)]
309pub enum ThreadError {
310    PaymentRequired,
311    MaxMonthlySpendReached,
312    Message(SharedString),
313}
314
315#[derive(Debug, Clone)]
316pub enum ThreadEvent {
317    ShowError(ThreadError),
318    StreamedCompletion,
319    UsePendingTools,
320    ToolFinished {
321        #[allow(unused)]
322        tool_use_id: LanguageModelToolUseId,
323    },
324}
325
326impl EventEmitter<ThreadEvent> for Thread {}
327
328struct PendingCompletion {
329    id: usize,
330    _task: Task<()>,
331}
332
333#[derive(Debug, Clone)]
334pub struct PendingToolUse {
335    pub id: LanguageModelToolUseId,
336    /// The ID of the Assistant message in which the tool use was requested.
337    pub assistant_message_id: MessageId,
338    pub name: String,
339    pub input: serde_json::Value,
340    pub status: PendingToolUseStatus,
341}
342
343#[derive(Debug, Clone)]
344pub enum PendingToolUseStatus {
345    Idle,
346    Running { _task: Shared<Task<()>> },
347    Error(#[allow(unused)] String),
348}
349
350impl PendingToolUseStatus {
351    pub fn is_idle(&self) -> bool {
352        matches!(self, PendingToolUseStatus::Idle)
353    }
354}