thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, EventEmitter, ModelContext, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
 11    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
 12};
 13use util::post_inc;
 14
 15#[derive(Debug, Clone, Copy)]
 16pub enum RequestKind {
 17    Chat,
 18}
 19
 20/// A message in a [`Thread`].
 21#[derive(Debug, Clone)]
 22pub struct Message {
 23    pub role: Role,
 24    pub text: String,
 25    pub tool_uses: Vec<LanguageModelToolUse>,
 26    pub tool_results: Vec<LanguageModelToolResult>,
 27}
 28
/// A thread of conversation with the LLM.
pub struct Thread {
    /// The conversation so far, in insertion order.
    messages: Vec<Message>,
    /// Counter used to hand out ids for entries in `pending_completions`.
    completion_count: usize,
    /// Completions that have been started via `stream_completion`, each
    /// holding the task that drives it.
    pending_completions: Vec<PendingCompletion>,
    /// The tool working set this thread was created with.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses reported by the model while streaming, keyed by tool-use id.
    pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
    /// Outputs of tools that finished successfully, keyed by tool-use id.
    /// Drained into the next message created by `insert_user_message`.
    completed_tool_uses_by_id: HashMap<Arc<str>, String>,
}
 38
 39impl Thread {
 40    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
 41        Self {
 42            tools,
 43            messages: Vec::new(),
 44            completion_count: 0,
 45            pending_completions: Vec::new(),
 46            pending_tool_uses_by_id: HashMap::default(),
 47            completed_tool_uses_by_id: HashMap::default(),
 48        }
 49    }
 50
 51    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 52        self.messages.iter()
 53    }
 54
 55    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 56        &self.tools
 57    }
 58
 59    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
 60        self.pending_tool_uses_by_id.values().collect()
 61    }
 62
 63    pub fn insert_user_message(&mut self, text: impl Into<String>) {
 64        let mut message = Message {
 65            role: Role::User,
 66            text: text.into(),
 67            tool_uses: Vec::new(),
 68            tool_results: Vec::new(),
 69        };
 70
 71        for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
 72            message.tool_results.push(LanguageModelToolResult {
 73                tool_use_id: tool_use_id.to_string(),
 74                content: tool_output,
 75                is_error: false,
 76            });
 77        }
 78
 79        self.messages.push(message);
 80    }
 81
 82    pub fn to_completion_request(
 83        &self,
 84        _request_kind: RequestKind,
 85        _cx: &AppContext,
 86    ) -> LanguageModelRequest {
 87        let mut request = LanguageModelRequest {
 88            messages: vec![],
 89            tools: Vec::new(),
 90            stop: Vec::new(),
 91            temperature: None,
 92        };
 93
 94        for message in &self.messages {
 95            let mut request_message = LanguageModelRequestMessage {
 96                role: message.role,
 97                content: Vec::new(),
 98                cache: false,
 99            };
100
101            for tool_result in &message.tool_results {
102                request_message
103                    .content
104                    .push(MessageContent::ToolResult(tool_result.clone()));
105            }
106
107            if !message.text.is_empty() {
108                request_message
109                    .content
110                    .push(MessageContent::Text(message.text.clone()));
111            }
112
113            for tool_use in &message.tool_uses {
114                request_message
115                    .content
116                    .push(MessageContent::ToolUse(tool_use.clone()));
117            }
118
119            request.messages.push(request_message);
120        }
121
122        request
123    }
124
125    pub fn stream_completion(
126        &mut self,
127        request: LanguageModelRequest,
128        model: Arc<dyn LanguageModel>,
129        cx: &mut ModelContext<Self>,
130    ) {
131        let pending_completion_id = post_inc(&mut self.completion_count);
132
133        let task = cx.spawn(|thread, mut cx| async move {
134            let stream = model.stream_completion(request, &cx);
135            let stream_completion = async {
136                let mut events = stream.await?;
137                let mut stop_reason = StopReason::EndTurn;
138
139                while let Some(event) = events.next().await {
140                    let event = event?;
141
142                    thread.update(&mut cx, |thread, cx| {
143                        match event {
144                            LanguageModelCompletionEvent::StartMessage { .. } => {
145                                thread.messages.push(Message {
146                                    role: Role::Assistant,
147                                    text: String::new(),
148                                    tool_uses: Vec::new(),
149                                    tool_results: Vec::new(),
150                                });
151                            }
152                            LanguageModelCompletionEvent::Stop(reason) => {
153                                stop_reason = reason;
154                            }
155                            LanguageModelCompletionEvent::Text(chunk) => {
156                                if let Some(last_message) = thread.messages.last_mut() {
157                                    if last_message.role == Role::Assistant {
158                                        last_message.text.push_str(&chunk);
159                                    }
160                                }
161                            }
162                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
163                                if let Some(last_message) = thread.messages.last_mut() {
164                                    if last_message.role == Role::Assistant {
165                                        last_message.tool_uses.push(tool_use.clone());
166                                    }
167                                }
168
169                                let tool_use_id: Arc<str> = tool_use.id.into();
170                                thread.pending_tool_uses_by_id.insert(
171                                    tool_use_id.clone(),
172                                    PendingToolUse {
173                                        id: tool_use_id,
174                                        name: tool_use.name,
175                                        input: tool_use.input,
176                                        status: PendingToolUseStatus::Idle,
177                                    },
178                                );
179                            }
180                        }
181
182                        cx.emit(ThreadEvent::StreamedCompletion);
183                        cx.notify();
184                    })?;
185
186                    smol::future::yield_now().await;
187                }
188
189                thread.update(&mut cx, |thread, _cx| {
190                    thread
191                        .pending_completions
192                        .retain(|completion| completion.id != pending_completion_id);
193                })?;
194
195                anyhow::Ok(stop_reason)
196            };
197
198            let result = stream_completion.await;
199
200            thread
201                .update(&mut cx, |_thread, cx| {
202                    let error_message = if let Some(error) = result.as_ref().err() {
203                        let error_message = error
204                            .chain()
205                            .map(|err| err.to_string())
206                            .collect::<Vec<_>>()
207                            .join("\n");
208                        Some(error_message)
209                    } else {
210                        None
211                    };
212
213                    if let Some(error_message) = error_message {
214                        eprintln!("Completion failed: {error_message:?}");
215                    }
216
217                    if let Ok(stop_reason) = result {
218                        match stop_reason {
219                            StopReason::ToolUse => {
220                                cx.emit(ThreadEvent::UsePendingTools);
221                            }
222                            StopReason::EndTurn => {}
223                            StopReason::MaxTokens => {}
224                        }
225                    }
226                })
227                .ok();
228        });
229
230        self.pending_completions.push(PendingCompletion {
231            id: pending_completion_id,
232            _task: task,
233        });
234    }
235
236    pub fn insert_tool_output(
237        &mut self,
238        tool_use_id: Arc<str>,
239        output: Task<Result<String>>,
240        cx: &mut ModelContext<Self>,
241    ) {
242        let insert_output_task = cx.spawn(|thread, mut cx| {
243            let tool_use_id = tool_use_id.clone();
244            async move {
245                let output = output.await;
246                thread
247                    .update(&mut cx, |thread, cx| match output {
248                        Ok(output) => {
249                            thread
250                                .completed_tool_uses_by_id
251                                .insert(tool_use_id.clone(), output);
252
253                            cx.emit(ThreadEvent::ToolFinished { tool_use_id });
254                        }
255                        Err(err) => {
256                            if let Some(tool_use) =
257                                thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
258                            {
259                                tool_use.status = PendingToolUseStatus::Error(err.to_string());
260                            }
261                        }
262                    })
263                    .ok();
264            }
265        });
266
267        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
268            tool_use.status = PendingToolUseStatus::Running {
269                _task: insert_output_task.shared(),
270            };
271        }
272    }
273}
274
/// Events emitted by a `Thread`.
#[derive(Clone, Debug)]
pub enum ThreadEvent {
    /// A streamed completion event has been applied to the thread.
    StreamedCompletion,
    /// A completion stopped because the model requested tool use.
    UsePendingTools,
    /// A tool finished successfully and its output has been recorded.
    ToolFinished {
        /// Id of the tool use that finished.
        #[allow(unused)]
        tool_use_id: Arc<str>,
    },
}
284
// Declares `Thread` as an emitter of `ThreadEvent`; the methods on `Thread`
// emit these events through `cx.emit(...)`.
impl EventEmitter<ThreadEvent> for Thread {}
286
/// Bookkeeping entry for a completion started by `Thread::stream_completion`.
struct PendingCompletion {
    /// Identifier taken from the thread's `completion_count` counter.
    id: usize,
    /// The spawned task that streams the completion. It is never read; it is
    /// stored only so the handle is owned by the thread (presumably to keep
    /// the task alive — confirm against gpui's `Task` drop semantics).
    _task: Task<()>,
}
291
292#[derive(Debug, Clone)]
293pub struct PendingToolUse {
294    pub id: Arc<str>,
295    pub name: String,
296    pub input: serde_json::Value,
297    pub status: PendingToolUseStatus,
298}
299
300#[derive(Debug, Clone)]
301pub enum PendingToolUseStatus {
302    Idle,
303    Running { _task: Shared<Task<()>> },
304    Error(#[allow(unused)] String),
305}
306
307impl PendingToolUseStatus {
308    pub fn is_idle(&self) -> bool {
309        matches!(self, PendingToolUseStatus::Idle)
310    }
311}