thread.rs

  1use crate::templates::Templates;
  2use anyhow::{anyhow, Result};
  3use futures::{channel::mpsc, future};
  4use gpui::{App, Context, SharedString, Task};
  5use language_model::{
  6    CompletionIntent, CompletionMode, LanguageModel, LanguageModelCompletionError,
  7    LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
  8    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
  9    LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason,
 10};
 11use schemars::{JsonSchema, Schema};
 12use serde::Deserialize;
 13use smol::stream::StreamExt;
 14use std::{collections::BTreeMap, sync::Arc};
 15use util::ResultExt;
 16
 17#[derive(Debug)]
 18pub struct AgentMessage {
 19    pub role: Role,
 20    pub content: Vec<MessageContent>,
 21}
 22
 23pub type AgentResponseEvent = LanguageModelCompletionEvent;
 24
 25pub trait Prompt {
 26    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
 27}
 28
 29pub struct Thread {
 30    messages: Vec<AgentMessage>,
 31    completion_mode: CompletionMode,
 32    /// Holds the task that handles agent interaction until the end of the turn.
 33    /// Survives across multiple requests as the model performs tool calls and
 34    /// we run tools, report their results.
 35    running_turn: Option<Task<()>>,
 36    system_prompts: Vec<Arc<dyn Prompt>>,
 37    tools: BTreeMap<SharedString, Arc<dyn AgentToolErased>>,
 38    templates: Arc<Templates>,
 39    // project: Entity<Project>,
 40    // action_log: Entity<ActionLog>,
 41}
 42
 43impl Thread {
 44    pub fn new(templates: Arc<Templates>) -> Self {
 45        Self {
 46            messages: Vec::new(),
 47            completion_mode: CompletionMode::Normal,
 48            system_prompts: Vec::new(),
 49            running_turn: None,
 50            tools: BTreeMap::default(),
 51            templates,
 52        }
 53    }
 54
 55    pub fn set_mode(&mut self, mode: CompletionMode) {
 56        self.completion_mode = mode;
 57    }
 58
 59    pub fn messages(&self) -> &[AgentMessage] {
 60        &self.messages
 61    }
 62
 63    pub fn add_tool(&mut self, tool: impl AgentTool) {
 64        self.tools.insert(tool.name(), tool.erase());
 65    }
 66
 67    pub fn remove_tool(&mut self, name: &str) -> bool {
 68        self.tools.remove(name).is_some()
 69    }
 70
 71    /// Sending a message results in the model streaming a response, which could include tool calls.
 72    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 73    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 74    pub fn send(
 75        &mut self,
 76        model: Arc<dyn LanguageModel>,
 77        content: impl Into<MessageContent>,
 78        cx: &mut Context<Self>,
 79    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 80        cx.notify();
 81        let (events_tx, events_rx) =
 82            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 83
 84        let system_message = self.build_system_message(cx);
 85        self.messages.extend(system_message);
 86
 87        self.messages.push(AgentMessage {
 88            role: Role::User,
 89            content: vec![content.into()],
 90        });
 91        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 92            let turn_result = async {
 93                // Perform one request, then keep looping if the model makes tool calls.
 94                let mut completion_intent = CompletionIntent::UserPrompt;
 95                loop {
 96                    let request = thread.update(cx, |thread, cx| {
 97                        thread.build_completion_request(completion_intent, cx)
 98                    })?;
 99
100                    // println!(
101                    //     "request: {}",
102                    //     serde_json::to_string_pretty(&request).unwrap()
103                    // );
104
105                    // Stream events, appending to messages and collecting up tool uses.
106                    let mut events = model.stream_completion(request, cx).await?;
107                    let mut tool_uses = Vec::new();
108                    while let Some(event) = events.next().await {
109                        match event {
110                            Ok(event) => {
111                                thread
112                                    .update(cx, |thread, cx| {
113                                        tool_uses.extend(thread.handle_streamed_completion_event(
114                                            event,
115                                            events_tx.clone(),
116                                            cx,
117                                        ));
118                                    })
119                                    .ok();
120                            }
121                            Err(error) => {
122                                events_tx.unbounded_send(Err(error)).ok();
123                                break;
124                            }
125                        }
126                    }
127
128                    // If there are no tool uses, the turn is done.
129                    if tool_uses.is_empty() {
130                        break;
131                    }
132
133                    // If there are tool uses, wait for their results to be
134                    // computed, then send them together in a single message on
135                    // the next loop iteration.
136                    let tool_results = future::join_all(tool_uses).await;
137                    thread
138                        .update(cx, |thread, _cx| {
139                            thread.messages.push(AgentMessage {
140                                role: Role::User,
141                                content: tool_results.into_iter().map(Into::into).collect(),
142                            });
143                        })
144                        .ok();
145                    completion_intent = CompletionIntent::ToolResults;
146                }
147
148                Ok(())
149            }
150            .await;
151
152            if let Err(error) = turn_result {
153                events_tx.unbounded_send(Err(error)).ok();
154            }
155        }));
156        events_rx
157    }
158
159    pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
160        let mut system_message = AgentMessage {
161            role: Role::System,
162            content: Vec::new(),
163        };
164
165        for prompt in &self.system_prompts {
166            if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
167                system_message
168                    .content
169                    .push(MessageContent::Text(rendered_prompt));
170            }
171        }
172
173        (!system_message.content.is_empty()).then_some(system_message)
174    }
175
176    /// A helper method that's called on every streamed completion event.
177    /// Returns an optional tool result task, which the main agentic loop in
178    /// send will send back to the model when it resolves.
179    fn handle_streamed_completion_event(
180        &mut self,
181        event: LanguageModelCompletionEvent,
182        events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
183        cx: &mut Context<Self>,
184    ) -> Option<Task<LanguageModelToolResult>> {
185        use LanguageModelCompletionEvent::*;
186        events_tx.unbounded_send(Ok(event.clone())).ok();
187
188        match event {
189            Text(new_text) => self.handle_text_event(new_text, cx),
190            Thinking { text, signature } => {
191                todo!()
192            }
193            ToolUse(tool_use) => {
194                return self.handle_tool_use_event(tool_use, cx);
195            }
196            StartMessage { role, .. } => {
197                self.messages.push(AgentMessage {
198                    role,
199                    content: Vec::new(),
200                });
201            }
202            UsageUpdate(_) => {}
203            Stop(stop_reason) => self.handle_stop_event(stop_reason),
204            StatusUpdate(_completion_request_status) => {}
205            RedactedThinking { data } => todo!(),
206            ToolUseJsonParseError {
207                id,
208                tool_name,
209                raw_input,
210                json_parse_error,
211            } => todo!(),
212        }
213
214        None
215    }
216
217    fn handle_stop_event(&mut self, stop_reason: StopReason) {
218        match stop_reason {
219            StopReason::EndTurn | StopReason::ToolUse => {}
220            StopReason::MaxTokens => todo!(),
221            StopReason::Refusal => todo!(),
222        }
223    }
224
225    fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
226        let last_message = self.last_assistant_message();
227        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
228            text.push_str(&new_text);
229        } else {
230            last_message.content.push(MessageContent::Text(new_text));
231        }
232
233        cx.notify();
234    }
235
236    fn handle_tool_use_event(
237        &mut self,
238        tool_use: LanguageModelToolUse,
239        cx: &mut Context<Self>,
240    ) -> Option<Task<LanguageModelToolResult>> {
241        cx.notify();
242
243        let last_message = self.last_assistant_message();
244
245        // Ensure the last message ends in the current tool use
246        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
247            if let MessageContent::ToolUse(last_tool_use) = content {
248                if last_tool_use.id == tool_use.id {
249                    *last_tool_use = tool_use.clone();
250                    false
251                } else {
252                    true
253                }
254            } else {
255                true
256            }
257        });
258        if push_new_tool_use {
259            last_message.content.push(tool_use.clone().into());
260        }
261
262        if !tool_use.is_input_complete {
263            return None;
264        }
265
266        if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
267            let pending_tool_result = tool.clone().run(tool_use.input, cx);
268
269            Some(cx.foreground_executor().spawn(async move {
270                match pending_tool_result.await {
271                    Ok(tool_output) => LanguageModelToolResult {
272                        tool_use_id: tool_use.id,
273                        tool_name: tool_use.name,
274                        is_error: false,
275                        content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
276                        output: None,
277                    },
278                    Err(error) => LanguageModelToolResult {
279                        tool_use_id: tool_use.id,
280                        tool_name: tool_use.name,
281                        is_error: true,
282                        content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
283                        output: None,
284                    },
285                }
286            }))
287        } else {
288            Some(Task::ready(LanguageModelToolResult {
289                content: LanguageModelToolResultContent::Text(Arc::from(format!(
290                    "No tool named {} exists",
291                    tool_use.name
292                ))),
293                tool_use_id: tool_use.id,
294                tool_name: tool_use.name,
295                is_error: true,
296                output: None,
297            }))
298        }
299    }
300
301    /// Guarantees the last message is from the assistant and returns a mutable reference.
302    fn last_assistant_message(&mut self) -> &mut AgentMessage {
303        if self
304            .messages
305            .last()
306            .map_or(true, |m| m.role != Role::Assistant)
307        {
308            self.messages.push(AgentMessage {
309                role: Role::Assistant,
310                content: Vec::new(),
311            });
312        }
313        self.messages.last_mut().unwrap()
314    }
315
316    fn build_completion_request(
317        &self,
318        completion_intent: CompletionIntent,
319        cx: &mut App,
320    ) -> LanguageModelRequest {
321        LanguageModelRequest {
322            thread_id: None,
323            prompt_id: None,
324            intent: Some(completion_intent),
325            mode: Some(self.completion_mode),
326            messages: self.build_request_messages(),
327            tools: self
328                .tools
329                .values()
330                .filter_map(|tool| {
331                    Some(LanguageModelRequestTool {
332                        name: tool.name().to_string(),
333                        description: tool.description(cx).to_string(),
334                        input_schema: tool
335                            .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
336                            .log_err()?,
337                    })
338                })
339                .collect(),
340            tool_choice: None,
341            stop: Vec::new(),
342            temperature: None,
343        }
344    }
345
346    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
347        self.messages
348            .iter()
349            .map(|message| LanguageModelRequestMessage {
350                role: message.role,
351                content: message.content.clone(),
352                cache: false,
353            })
354            .collect()
355    }
356}
357
358pub trait AgentTool
359where
360    Self: 'static + Sized,
361{
362    type Input: for<'de> Deserialize<'de> + JsonSchema;
363
364    fn name(&self) -> SharedString;
365    fn description(&self, _cx: &mut App) -> SharedString {
366        let schema = schemars::schema_for!(Self::Input);
367        SharedString::new(
368            schema
369                .get("description")
370                .and_then(|description| description.as_str())
371                .unwrap_or_default(),
372        )
373    }
374
375    /// Returns the JSON schema that describes the tool's input.
376    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
377        assistant_tools::root_schema_for::<Self::Input>(format)
378    }
379
380    /// Runs the tool with the provided input.
381    fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
382
383    fn erase(self) -> Arc<dyn AgentToolErased> {
384        Arc::new(Erased(Arc::new(self)))
385    }
386}
387
/// Wrapper through which an [`AgentTool`] is exposed as an [`AgentToolErased`].
pub struct Erased<T>(T);
389
390pub trait AgentToolErased {
391    fn name(&self) -> SharedString;
392    fn description(&self, cx: &mut App) -> SharedString;
393    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
394    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
395}
396
397impl<T> AgentToolErased for Erased<Arc<T>>
398where
399    T: AgentTool,
400{
401    fn name(&self) -> SharedString {
402        self.0.name()
403    }
404
405    fn description(&self, cx: &mut App) -> SharedString {
406        self.0.description(cx)
407    }
408
409    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
410        Ok(serde_json::to_value(self.0.input_schema(format))?)
411    }
412
413    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
414        let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
415        match parsed_input {
416            Ok(input) => self.0.clone().run(input, cx),
417            Err(error) => Task::ready(Err(anyhow!(error))),
418        }
419    }
420}