use crate::templates::Templates;
use anyhow::{anyhow, Result};
use cloud_llm_client::{CompletionIntent, CompletionMode};
use futures::{channel::mpsc, future};
use gpui::{App, Context, SharedString, Task};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
    LanguageModelToolUse, MessageContent, Role, StopReason,
};
use schemars::{JsonSchema, Schema};
use serde::Deserialize;
use smol::stream::StreamExt;
use std::{collections::BTreeMap, sync::Arc};
use util::ResultExt;

#[derive(Debug)]
pub struct AgentMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
}

pub type AgentResponseEvent = LanguageModelCompletionEvent;

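/// A source of system prompt text that gets rendered into the [`Role::System`]
/// message at the start of each call to [`Thread::send`].
///
/// A minimal sketch of an implementation; the struct and the prompt text are
/// illustrative only and not part of this crate:
///
/// ```ignore
/// struct BaseSystemPrompt;
///
/// impl Prompt for BaseSystemPrompt {
///     fn render(&self, _prompts: &Templates, _cx: &App) -> Result<String> {
///         Ok("You are a careful coding agent.".to_string())
///     }
/// }
/// ```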
pub trait Prompt {
    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
}

/// A single agent conversation: the accumulated message history plus the task
/// driving the current turn.
pub struct Thread {
    messages: Vec<AgentMessage>,
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    running_turn: Option<Task<()>>,
    system_prompts: Vec<Arc<dyn Prompt>>,
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    templates: Arc<Templates>,
    pub selected_model: Arc<dyn LanguageModel>,
    // project: Entity<Project>,
    // action_log: Entity<ActionLog>,
}

impl Thread {
    pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
        Self {
            messages: Vec::new(),
            completion_mode: CompletionMode::Normal,
            system_prompts: Vec::new(),
            running_turn: None,
            tools: BTreeMap::default(),
            templates,
            selected_model: default_model,
        }
    }

    pub fn set_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    pub fn messages(&self) -> &[AgentMessage] {
        &self.messages
    }

    pub fn add_tool(&mut self, tool: impl AgentTool) {
        self.tools.insert(tool.name(), tool.erase());
    }

    pub fn remove_tool(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }

    /// Sending a message results in the model streaming a response, which could include tool calls.
    /// After requesting tool calls, the model stops and waits for any outstanding tool calls to be completed and their results sent.
    /// The returned channel reports every occurrence of the model stopping before it errors or ends its turn.
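    ///
    /// A hedged sketch of driving a turn from gpui code; `templates`, `model`,
    /// and the surrounding `&mut App` context are assumed to exist in the
    /// caller and are not defined by this module:
    ///
    /// ```ignore
    /// let thread = cx.new(|_cx| Thread::new(templates.clone(), model.clone()));
    /// let mut events = thread.update(cx, |thread, cx| {
    ///     thread.send(model.clone(), "Summarize the open buffer".to_string(), cx)
    /// });
    /// cx.spawn(async move |_cx| {
    ///     // Drain the channel until the turn ends or errors.
    ///     while let Some(event) = events.next().await {
    ///         match event {
    ///             Ok(event) => println!("{event:?}"),
    ///             Err(error) => eprintln!("{error}"),
    ///         }
    ///     }
    /// })
    /// .detach();
    /// ```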
    pub fn send(
        &mut self,
        model: Arc<dyn LanguageModel>,
        content: impl Into<MessageContent>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
        cx.notify();
        let (events_tx, events_rx) =
            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();

        let system_message = self.build_system_message(cx);
        self.messages.extend(system_message);

        self.messages.push(AgentMessage {
            role: Role::User,
            content: vec![content.into()],
        });
        self.running_turn = Some(cx.spawn(async move |thread, cx| {
            let turn_result = async {
                // Perform one request, then keep looping if the model makes tool calls.
                let mut completion_intent = CompletionIntent::UserPrompt;
                loop {
                    let request = thread.update(cx, |thread, cx| {
                        thread.build_completion_request(completion_intent, cx)
                    })?;

                    // println!(
                    //     "request: {}",
                    //     serde_json::to_string_pretty(&request).unwrap()
                    // );

                    // Stream events, appending to messages and collecting up tool uses.
                    let mut events = model.stream_completion(request, cx).await?;
                    let mut tool_uses = Vec::new();
                    while let Some(event) = events.next().await {
                        match event {
                            Ok(event) => {
                                thread
                                    .update(cx, |thread, cx| {
                                        tool_uses.extend(thread.handle_streamed_completion_event(
                                            event,
                                            events_tx.clone(),
                                            cx,
                                        ));
                                    })
                                    .ok();
                            }
                            Err(error) => {
                                events_tx.unbounded_send(Err(error)).ok();
                                break;
                            }
                        }
                    }

                    // If there are no tool uses, the turn is done.
                    if tool_uses.is_empty() {
                        break;
                    }

                    // If there are tool uses, wait for their results to be
                    // computed, then send them together in a single message on
                    // the next loop iteration.
                    let tool_results = future::join_all(tool_uses).await;
                    thread
                        .update(cx, |thread, _cx| {
                            thread.messages.push(AgentMessage {
                                role: Role::User,
                                content: tool_results
                                    .into_iter()
                                    .map(MessageContent::ToolResult)
                                    .collect(),
                            });
                        })
                        .ok();
                    completion_intent = CompletionIntent::ToolResults;
                }

                Ok(())
            }
            .await;

            if let Err(error) = turn_result {
                events_tx.unbounded_send(Err(error)).ok();
            }
        }));
        events_rx
    }

    pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
        let mut system_message = AgentMessage {
            role: Role::System,
            content: Vec::new(),
        };

        for prompt in &self.system_prompts {
            if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
                system_message
                    .content
                    .push(MessageContent::Text(rendered_prompt));
            }
        }

        (!system_message.content.is_empty()).then_some(system_message)
    }

    /// A helper method that's called on every streamed completion event.
    /// Returns an optional tool result task, which the main agentic loop in
    /// send will send back to the model when it resolves.
    fn handle_streamed_completion_event(
        &mut self,
        event: LanguageModelCompletionEvent,
        events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
        cx: &mut Context<Self>,
    ) -> Option<Task<LanguageModelToolResult>> {
        use LanguageModelCompletionEvent::*;
        events_tx.unbounded_send(Ok(event.clone())).ok();

        match event {
            Text(new_text) => self.handle_text_event(new_text, cx),
            Thinking {
                text: _text,
                signature: _signature,
            } => {
                todo!()
            }
            ToolUse(tool_use) => {
                return self.handle_tool_use_event(tool_use, cx);
            }
            StartMessage { .. } => {
                self.messages.push(AgentMessage {
                    role: Role::Assistant,
                    content: Vec::new(),
                });
            }
            UsageUpdate(_) => {}
            Stop(stop_reason) => self.handle_stop_event(stop_reason),
            StatusUpdate(_completion_request_status) => {}
            RedactedThinking { data: _data } => todo!(),
            ToolUseJsonParseError {
                id: _id,
                tool_name: _tool_name,
                raw_input: _raw_input,
                json_parse_error: _json_parse_error,
            } => todo!(),
        }

        None
    }

    fn handle_stop_event(&mut self, stop_reason: StopReason) {
        match stop_reason {
            StopReason::EndTurn | StopReason::ToolUse => {}
            StopReason::MaxTokens => todo!(),
            StopReason::Refusal => todo!(),
        }
    }

    fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
        let last_message = self.last_assistant_message();
        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
            text.push_str(&new_text);
        } else {
            last_message.content.push(MessageContent::Text(new_text));
        }

        cx.notify();
    }

    fn handle_tool_use_event(
        &mut self,
        tool_use: LanguageModelToolUse,
        cx: &mut Context<Self>,
    ) -> Option<Task<LanguageModelToolResult>> {
        cx.notify();

        let last_message = self.last_assistant_message();

        // Ensure the last message ends with the current tool use.
        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
            if let MessageContent::ToolUse(last_tool_use) = content {
                if last_tool_use.id == tool_use.id {
                    *last_tool_use = tool_use.clone();
                    false
                } else {
                    true
                }
            } else {
                true
            }
        });
        if push_new_tool_use {
            last_message
                .content
                .push(MessageContent::ToolUse(tool_use.clone()));
        }

        if !tool_use.is_input_complete {
            return None;
        }

        if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
            let pending_tool_result = tool.clone().run(tool_use.input, cx);

            Some(cx.foreground_executor().spawn(async move {
                match pending_tool_result.await {
                    Ok(tool_output) => LanguageModelToolResult {
                        tool_use_id: tool_use.id,
                        tool_name: tool_use.name,
                        is_error: false,
                        content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
                        output: None,
                    },
                    Err(error) => LanguageModelToolResult {
                        tool_use_id: tool_use.id,
                        tool_name: tool_use.name,
                        is_error: true,
                        content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
                        output: None,
                    },
                }
            }))
        } else {
            Some(Task::ready(LanguageModelToolResult {
                content: LanguageModelToolResultContent::Text(Arc::from(format!(
                    "No tool named {} exists",
                    tool_use.name
                ))),
                tool_use_id: tool_use.id,
                tool_name: tool_use.name,
                is_error: true,
                output: None,
            }))
        }
    }

    /// Guarantees the last message is from the assistant and returns a mutable reference.
    fn last_assistant_message(&mut self) -> &mut AgentMessage {
        if self
            .messages
            .last()
            .map_or(true, |m| m.role != Role::Assistant)
        {
            self.messages.push(AgentMessage {
                role: Role::Assistant,
                content: Vec::new(),
            });
        }
        self.messages.last_mut().unwrap()
    }

    fn build_completion_request(
        &self,
        completion_intent: CompletionIntent,
        cx: &mut App,
    ) -> LanguageModelRequest {
        LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: Some(completion_intent),
            mode: Some(self.completion_mode),
            messages: self.build_request_messages(),
            tools: self
                .tools
                .values()
                .filter_map(|tool| {
                    Some(LanguageModelRequestTool {
                        name: tool.name().to_string(),
                        description: tool.description(cx).to_string(),
                        input_schema: tool
                            .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
                            .log_err()?,
                    })
                })
                .collect(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: None,
            thinking_allowed: false,
        }
    }

    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
        self.messages
            .iter()
            .map(|message| LanguageModelRequestMessage {
                role: message.role,
                content: message.content.clone(),
                cache: false,
            })
            .collect()
    }
}

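/// A strongly typed tool that the model can invoke during a turn driven by
/// [`Thread::send`].
///
/// A minimal sketch of an implementation; the tool name and input shape below
/// are illustrative only and not defined by this crate:
///
/// ```ignore
/// #[derive(Deserialize, JsonSchema)]
/// struct EchoToolInput {
///     /// Text the tool echoes back to the model.
///     text: String,
/// }
///
/// struct EchoTool;
///
/// impl AgentTool for EchoTool {
///     type Input = EchoToolInput;
///
///     fn name(&self) -> SharedString {
///         "echo".into()
///     }
///
///     fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
///         Task::ready(Ok(input.text))
///     }
/// }
/// ```
///
/// Tools are registered with [`Thread::add_tool`], which type-erases them via
/// [`AgentTool::erase`] so differently typed tools can live in one map.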
pub trait AgentTool
where
    Self: 'static + Sized,
{
    type Input: for<'de> Deserialize<'de> + JsonSchema;

    fn name(&self) -> SharedString;
    fn description(&self, _cx: &mut App) -> SharedString {
        let schema = schemars::schema_for!(Self::Input);
        SharedString::new(
            schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap_or_default(),
        )
    }

    /// Returns the JSON schema that describes the tool's input.
    fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
        schemars::schema_for!(Self::Input)
    }

    /// Runs the tool with the provided input.
    fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;

    fn erase(self) -> Arc<dyn AnyAgentTool> {
        Arc::new(Erased(Arc::new(self)))
    }
}

/// Newtype that lets a concrete [`AgentTool`] be used through the type-erased
/// [`AnyAgentTool`] interface.
pub struct Erased<T>(T);

/// Object-safe version of [`AgentTool`]: input arrives as raw JSON and is
/// deserialized before the tool runs, so tools with different input types can
/// share one collection.
pub trait AnyAgentTool {
    fn name(&self) -> SharedString;
    fn description(&self, cx: &mut App) -> SharedString;
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}

impl<T> AnyAgentTool for Erased<Arc<T>>
where
    T: AgentTool,
{
    fn name(&self) -> SharedString {
        self.0.name()
    }

    fn description(&self, cx: &mut App) -> SharedString {
        self.0.description(cx)
    }

    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
        Ok(serde_json::to_value(self.0.input_schema(format))?)
    }

    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
        let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
        match parsed_input {
            Ok(input) => self.0.clone().run(input, cx),
            Err(error) => Task::ready(Err(anyhow!(error))),
        }
    }
}