thread.rs

  1use crate::templates::Templates;
  2use anyhow::{anyhow, Result};
  3use cloud_llm_client::{CompletionIntent, CompletionMode};
  4use futures::{channel::mpsc, future};
  5use gpui::{App, Context, SharedString, Task};
  6use language_model::{
  7    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  8    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  9    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
 10    LanguageModelToolUse, MessageContent, Role, StopReason,
 11};
 12use schemars::{JsonSchema, Schema};
 13use serde::Deserialize;
 14use smol::stream::StreamExt;
 15use std::{collections::BTreeMap, sync::Arc};
 16use util::ResultExt;
 17
 18#[derive(Debug)]
 19pub struct AgentMessage {
 20    pub role: Role,
 21    pub content: Vec<MessageContent>,
 22}
 23
 24pub type AgentResponseEvent = LanguageModelCompletionEvent;
 25
 26pub trait Prompt {
 27    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
 28}
 29
/// An agent conversation: the message history together with the prompts,
/// tools, and model settings used to continue it.
pub struct Thread {
    /// Conversation history, oldest message first.
    messages: Vec<AgentMessage>,
    /// Completion mode attached to every request this thread builds.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests while the model makes tool calls
    /// and the thread runs those tools and reports their results.
    running_turn: Option<Task<()>>,
    /// Prompts rendered (in order) into the system message.
    system_prompts: Vec<Arc<dyn Prompt>>,
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Templates passed to each system prompt when it is rendered.
    templates: Arc<Templates>,
    /// Initially the `default_model` given to `Thread::new`. Note that `send`
    /// streams from the model passed to it, not from this field.
    pub selected_model: Arc<dyn LanguageModel>,
    // Not yet wired up: `project: Entity<Project>` and
    // `action_log: Entity<ActionLog>`.
}
 44
 45impl Thread {
 46    pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
 47        Self {
 48            messages: Vec::new(),
 49            completion_mode: CompletionMode::Normal,
 50            system_prompts: Vec::new(),
 51            running_turn: None,
 52            tools: BTreeMap::default(),
 53            templates,
 54            selected_model: default_model,
 55        }
 56    }
 57
 58    pub fn set_mode(&mut self, mode: CompletionMode) {
 59        self.completion_mode = mode;
 60    }
 61
 62    pub fn messages(&self) -> &[AgentMessage] {
 63        &self.messages
 64    }
 65
 66    pub fn add_tool(&mut self, tool: impl AgentTool) {
 67        self.tools.insert(tool.name(), tool.erase());
 68    }
 69
 70    pub fn remove_tool(&mut self, name: &str) -> bool {
 71        self.tools.remove(name).is_some()
 72    }
 73
 74    /// Sending a message results in the model streaming a response, which could include tool calls.
 75    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 76    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 77    pub fn send(
 78        &mut self,
 79        model: Arc<dyn LanguageModel>,
 80        content: impl Into<MessageContent>,
 81        cx: &mut Context<Self>,
 82    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 83        cx.notify();
 84        let (events_tx, events_rx) =
 85            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 86
 87        let system_message = self.build_system_message(cx);
 88        self.messages.extend(system_message);
 89
 90        self.messages.push(AgentMessage {
 91            role: Role::User,
 92            content: vec![content.into()],
 93        });
 94        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 95            let turn_result = async {
 96                // Perform one request, then keep looping if the model makes tool calls.
 97                let mut completion_intent = CompletionIntent::UserPrompt;
 98                loop {
 99                    let request = thread.update(cx, |thread, cx| {
100                        thread.build_completion_request(completion_intent, cx)
101                    })?;
102
103                    // println!(
104                    //     "request: {}",
105                    //     serde_json::to_string_pretty(&request).unwrap()
106                    // );
107
108                    // Stream events, appending to messages and collecting up tool uses.
109                    let mut events = model.stream_completion(request, cx).await?;
110                    let mut tool_uses = Vec::new();
111                    while let Some(event) = events.next().await {
112                        match event {
113                            Ok(event) => {
114                                thread
115                                    .update(cx, |thread, cx| {
116                                        tool_uses.extend(thread.handle_streamed_completion_event(
117                                            event,
118                                            events_tx.clone(),
119                                            cx,
120                                        ));
121                                    })
122                                    .ok();
123                            }
124                            Err(error) => {
125                                events_tx.unbounded_send(Err(error)).ok();
126                                break;
127                            }
128                        }
129                    }
130
131                    // If there are no tool uses, the turn is done.
132                    if tool_uses.is_empty() {
133                        break;
134                    }
135
136                    // If there are tool uses, wait for their results to be
137                    // computed, then send them together in a single message on
138                    // the next loop iteration.
139                    let tool_results = future::join_all(tool_uses).await;
140                    thread
141                        .update(cx, |thread, _cx| {
142                            thread.messages.push(AgentMessage {
143                                role: Role::User,
144                                content: tool_results
145                                    .into_iter()
146                                    .map(MessageContent::ToolResult)
147                                    .collect(),
148                            });
149                        })
150                        .ok();
151                    completion_intent = CompletionIntent::ToolResults;
152                }
153
154                Ok(())
155            }
156            .await;
157
158            if let Err(error) = turn_result {
159                events_tx.unbounded_send(Err(error)).ok();
160            }
161        }));
162        events_rx
163    }
164
165    pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
166        let mut system_message = AgentMessage {
167            role: Role::System,
168            content: Vec::new(),
169        };
170
171        for prompt in &self.system_prompts {
172            if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
173                system_message
174                    .content
175                    .push(MessageContent::Text(rendered_prompt));
176            }
177        }
178
179        (!system_message.content.is_empty()).then_some(system_message)
180    }
181
182    /// A helper method that's called on every streamed completion event.
183    /// Returns an optional tool result task, which the main agentic loop in
184    /// send will send back to the model when it resolves.
185    fn handle_streamed_completion_event(
186        &mut self,
187        event: LanguageModelCompletionEvent,
188        events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
189        cx: &mut Context<Self>,
190    ) -> Option<Task<LanguageModelToolResult>> {
191        use LanguageModelCompletionEvent::*;
192        events_tx.unbounded_send(Ok(event.clone())).ok();
193
194        match event {
195            Text(new_text) => self.handle_text_event(new_text, cx),
196            Thinking {
197                text: _text,
198                signature: _signature,
199            } => {
200                todo!()
201            }
202            ToolUse(tool_use) => {
203                return self.handle_tool_use_event(tool_use, cx);
204            }
205            StartMessage { .. } => {
206                self.messages.push(AgentMessage {
207                    role: Role::Assistant,
208                    content: Vec::new(),
209                });
210            }
211            UsageUpdate(_) => {}
212            Stop(stop_reason) => self.handle_stop_event(stop_reason),
213            StatusUpdate(_completion_request_status) => {}
214            RedactedThinking { data: _data } => todo!(),
215            ToolUseJsonParseError {
216                id: _id,
217                tool_name: _tool_name,
218                raw_input: _raw_input,
219                json_parse_error: _json_parse_error,
220            } => todo!(),
221        }
222
223        None
224    }
225
226    fn handle_stop_event(&mut self, stop_reason: StopReason) {
227        match stop_reason {
228            StopReason::EndTurn | StopReason::ToolUse => {}
229            StopReason::MaxTokens => todo!(),
230            StopReason::Refusal => todo!(),
231        }
232    }
233
234    fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
235        let last_message = self.last_assistant_message();
236        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
237            text.push_str(&new_text);
238        } else {
239            last_message.content.push(MessageContent::Text(new_text));
240        }
241
242        cx.notify();
243    }
244
245    fn handle_tool_use_event(
246        &mut self,
247        tool_use: LanguageModelToolUse,
248        cx: &mut Context<Self>,
249    ) -> Option<Task<LanguageModelToolResult>> {
250        cx.notify();
251
252        let last_message = self.last_assistant_message();
253
254        // Ensure the last message ends in the current tool use
255        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
256            if let MessageContent::ToolUse(last_tool_use) = content {
257                if last_tool_use.id == tool_use.id {
258                    *last_tool_use = tool_use.clone();
259                    false
260                } else {
261                    true
262                }
263            } else {
264                true
265            }
266        });
267        if push_new_tool_use {
268            last_message
269                .content
270                .push(MessageContent::ToolUse(tool_use.clone()));
271        }
272
273        if !tool_use.is_input_complete {
274            return None;
275        }
276
277        if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
278            let pending_tool_result = tool.clone().run(tool_use.input, cx);
279
280            Some(cx.foreground_executor().spawn(async move {
281                match pending_tool_result.await {
282                    Ok(tool_output) => LanguageModelToolResult {
283                        tool_use_id: tool_use.id,
284                        tool_name: tool_use.name,
285                        is_error: false,
286                        content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
287                        output: None,
288                    },
289                    Err(error) => LanguageModelToolResult {
290                        tool_use_id: tool_use.id,
291                        tool_name: tool_use.name,
292                        is_error: true,
293                        content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
294                        output: None,
295                    },
296                }
297            }))
298        } else {
299            Some(Task::ready(LanguageModelToolResult {
300                content: LanguageModelToolResultContent::Text(Arc::from(format!(
301                    "No tool named {} exists",
302                    tool_use.name
303                ))),
304                tool_use_id: tool_use.id,
305                tool_name: tool_use.name,
306                is_error: true,
307                output: None,
308            }))
309        }
310    }
311
312    /// Guarantees the last message is from the assistant and returns a mutable reference.
313    fn last_assistant_message(&mut self) -> &mut AgentMessage {
314        if self
315            .messages
316            .last()
317            .map_or(true, |m| m.role != Role::Assistant)
318        {
319            self.messages.push(AgentMessage {
320                role: Role::Assistant,
321                content: Vec::new(),
322            });
323        }
324        self.messages.last_mut().unwrap()
325    }
326
327    fn build_completion_request(
328        &self,
329        completion_intent: CompletionIntent,
330        cx: &mut App,
331    ) -> LanguageModelRequest {
332        LanguageModelRequest {
333            thread_id: None,
334            prompt_id: None,
335            intent: Some(completion_intent),
336            mode: Some(self.completion_mode),
337            messages: self.build_request_messages(),
338            tools: self
339                .tools
340                .values()
341                .filter_map(|tool| {
342                    Some(LanguageModelRequestTool {
343                        name: tool.name().to_string(),
344                        description: tool.description(cx).to_string(),
345                        input_schema: tool
346                            .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
347                            .log_err()?,
348                    })
349                })
350                .collect(),
351            tool_choice: None,
352            stop: Vec::new(),
353            temperature: None,
354            thinking_allowed: false,
355        }
356    }
357
358    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
359        self.messages
360            .iter()
361            .map(|message| LanguageModelRequestMessage {
362                role: message.role,
363                content: message.content.clone(),
364                cache: false,
365            })
366            .collect()
367    }
368}
369
370pub trait AgentTool
371where
372    Self: 'static + Sized,
373{
374    type Input: for<'de> Deserialize<'de> + JsonSchema;
375
376    fn name(&self) -> SharedString;
377    fn description(&self, _cx: &mut App) -> SharedString {
378        let schema = schemars::schema_for!(Self::Input);
379        SharedString::new(
380            schema
381                .get("description")
382                .and_then(|description| description.as_str())
383                .unwrap_or_default(),
384        )
385    }
386
387    /// Returns the JSON schema that describes the tool's input.
388    fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
389        schemars::schema_for!(Self::Input)
390    }
391
392    /// Runs the tool with the provided input.
393    fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
394
395    fn erase(self) -> Arc<dyn AnyAgentTool> {
396        Arc::new(Erased(Arc::new(self)))
397    }
398}
399
400pub struct Erased<T>(T);
401
402pub trait AnyAgentTool {
403    fn name(&self) -> SharedString;
404    fn description(&self, cx: &mut App) -> SharedString;
405    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
406    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
407}
408
409impl<T> AnyAgentTool for Erased<Arc<T>>
410where
411    T: AgentTool,
412{
413    fn name(&self) -> SharedString {
414        self.0.name()
415    }
416
417    fn description(&self, cx: &mut App) -> SharedString {
418        self.0.description(cx)
419    }
420
421    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
422        Ok(serde_json::to_value(self.0.input_schema(format))?)
423    }
424
425    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
426        let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
427        match parsed_input {
428            Ok(input) => self.0.clone().run(input, cx),
429            Err(error) => Task::ready(Err(anyhow!(error))),
430        }
431    }
432}