request.rs

use std::io::{Cursor, Write};

use crate::role::Role;
use crate::LanguageModelToolUse;
use base64::write::EncoderWriter;
use gpui::{
    point, size, App, AppContext as _, DevicePixels, Image, ObjectFit, RenderImage, Size, Task,
};
use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}
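
// A small sketch of the custom `Debug` impl above: the base64 payload is
// redacted down to a byte count, so logging a request never dumps image data.
// The 1x1 size and the "aGVsbG8=" payload below are arbitrary test values.
#[cfg(test)]
mod debug_redaction_example {
    use super::*;

    #[test]
    fn debug_output_redacts_the_base64_payload() {
        let image = LanguageModelImage {
            source: SharedString::from("aGVsbG8="),
            size: size(DevicePixels(1), DevicePixels(1)),
        };
        let debug_output = format!("{:?}", image);
        // "aGVsbG8=" is 8 bytes, and only the length survives formatting.
        assert!(debug_output.contains("<8 bytes>"));
        assert!(!debug_output.contains("aGVsbG8="));
    }
}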

/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
    pub fn from_image(data: Image, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());

            // Decode with a decoder that matches the image's actual format;
            // the PNG decoder cannot parse JPEG, WebP, or GIF data.
            let image = match data.format() {
                gpui::ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(DynamicImage::from_decoder),
                gpui::ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(DynamicImage::from_decoder),
                gpui::ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(DynamicImage::from_decoder),
                gpui::ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            // Scale the image down, preserving its aspect ratio, if either
            // dimension exceeds Anthropic's limit.
            let image = if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
            {
                let new_bounds = ObjectFit::ScaleDown.get_bounds(
                    gpui::Bounds {
                        origin: point(px(0.0), px(0.0)),
                        size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                    },
                    image_size,
                );
                image.resize(
                    new_bounds.size.width.0 as u32,
                    new_bounds.size.height.0 as u32,
                    image::imageops::FilterType::Triangle,
                )
            } else {
                image
            };

            // Re-encode to PNG so the payload always matches the `image/png`
            // media type that `into_anthropic` advertises, even when the
            // original image was a JPEG, WebP, or GIF.
            let mut png = Vec::new();
            image
                .write_with_encoder(PngEncoder::new(&mut png))
                .log_err()?;

            let mut base64_image = Vec::new();

            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );
                base64_encoder.write_all(png.as_slice()).log_err()?;
            }

            // SAFETY: Base64 output is pure ASCII, so it is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves the image into an LLM-ready format (a base64-encoded PNG).
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
        // Convert from BGRA to RGBA.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )
        .expect("RenderImage dimensions always match its byte buffer");

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: Base64 output is pure ASCII, so it is always valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and
        // OpenAI doesn't use this formula, so this is only a rough estimate.
        (width * height) / 750
    }
}
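
// A hedged sketch of the arithmetic in `estimate_tokens` above: Anthropic's
// documented rule of thumb is `width * height / 750`, so a 1024x1024 image
// works out to 1_048_576 / 750 = 1398 tokens under integer division. The
// empty source below is an arbitrary test value; only the size matters here.
#[cfg(test)]
mod estimate_tokens_example {
    use super::*;

    #[test]
    fn square_image_token_estimate() {
        let image = LanguageModelImage {
            source: SharedString::from(""),
            size: size(DevicePixels(1024), DevicePixels(1024)),
        };
        assert_eq!(image.estimate_tokens(), 1398);
    }
}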

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: String,
    pub is_error: bool,
    pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}
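
// A small usage sketch: the two `From` impls above let callers write
// `"text".into()` or `some_string.into()` wherever `MessageContent` is
// expected, instead of spelling out `MessageContent::Text(..)`.
#[cfg(test)]
mod message_content_example {
    use super::*;

    #[test]
    fn strings_convert_into_text_content() {
        let from_str: MessageContent = "hello".into();
        let from_string: MessageContent = String::from("hello").into();
        assert_eq!(from_str, MessageContent::Text("hello".to_string()));
        assert_eq!(from_str, from_string);
    }
}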

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text),
            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str())
        }
        string_buffer
    }

    pub fn contents_empty(&self) -> bool {
        // Note that only the first content entry is inspected here.
        self.content.is_empty()
            || self
                .content
                .first()
                .map(|content| match content {
                    MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
                    MessageContent::ToolResult(tool_result) => {
                        tool_result.content.chars().all(|c| c.is_whitespace())
                    }
                    MessageContent::ToolUse(_) | MessageContent::Image(_) => true,
                })
                .unwrap_or(false)
    }
}
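
// Hedged behavior sketches for the helpers above: `string_contents`
// concatenates text and tool-result content while skipping images and tool
// uses, and `contents_empty` judges the message by its first content entry
// alone. The tool ID and strings below are arbitrary test values.
#[cfg(test)]
mod request_message_example {
    use super::*;

    #[test]
    fn string_contents_skips_non_text_content() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Hello, ".into(),
                MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: "tool-1".to_string(),
                    is_error: false,
                    content: "world".to_string(),
                }),
            ],
            cache: false,
        };
        assert_eq!(message.string_contents(), "Hello, world");
        assert!(!message.contents_empty());
    }
}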

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
}

impl LanguageModelRequest {
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        // OpenAI's o1 models don't support streaming responses.
        let stream = !model.starts_with("o1-");
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream,
            stop: self.stop,
            temperature: self.temperature.unwrap_or(1.0),
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

    pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
        let messages = self
            .messages
            .into_iter()
            .map(|msg| {
                let content = msg.string_contents();
                match msg.role {
                    Role::User => mistral::RequestMessage::User { content },
                    Role::Assistant => mistral::RequestMessage::Assistant {
                        content: Some(content),
                        tool_calls: Vec::new(),
                    },
                    Role::System => mistral::RequestMessage::System { content },
                }
            })
            .collect();

        mistral::Request {
            model,
            messages,
            stream: true,
            max_tokens: max_output_tokens,
            temperature: self.temperature,
            response_format: None,
            tools: self
                .tools
                .into_iter()
                .map(|tool| mistral::ToolDefinition::Function {
                    function: mistral::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        }
    }

    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

    pub fn into_anthropic(
        self,
        model: String,
        default_temperature: f32,
        max_output_tokens: u32,
    ) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolUse(tool_use) => {
                                Some(anthropic::RequestContent::ToolUse {
                                    id: tool_use.id.to_string(),
                                    name: tool_use.name,
                                    input: tool_use.input,
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
                    // Anthropic requires user and assistant messages to
                    // alternate, so merge consecutive same-role messages.
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    // Anthropic takes the system prompt as a separate request
                    // field; concatenate all system messages into it.
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: self.temperature.or(Some(default_temperature)),
            top_k: None,
            top_p: None,
        }
    }

    pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
        let is_reasoner = model == "deepseek-reasoner";

        let len = self.messages.len();
        // deepseek-reasoner rejects consecutive messages with the same role,
        // so adjacent user (or assistant) messages are merged into one.
        let merged_messages =
            self.messages
                .into_iter()
                .fold(Vec::with_capacity(len), |mut acc, msg| {
                    let role = msg.role;
                    let content = msg.string_contents();

                    if is_reasoner {
                        if let Some(last_msg) = acc.last_mut() {
                            match (last_msg, role) {
                                (deepseek::RequestMessage::User { content: last }, Role::User) => {
                                    last.push(' ');
                                    last.push_str(&content);
                                    return acc;
                                }

                                (
                                    deepseek::RequestMessage::Assistant {
                                        content: last_content,
                                        ..
                                    },
                                    Role::Assistant,
                                ) => {
                                    *last_content = last_content
                                        .take()
                                        .map(|c| {
                                            let mut s =
                                                String::with_capacity(c.len() + content.len() + 1);
                                            s.push_str(&c);
                                            s.push(' ');
                                            s.push_str(&content);
                                            s
                                        })
                                        .or(Some(content));

                                    return acc;
                                }
                                _ => {}
                            }
                        }
                    }

                    acc.push(match role {
                        Role::User => deepseek::RequestMessage::User { content },
                        Role::Assistant => deepseek::RequestMessage::Assistant {
                            content: Some(content),
                            tool_calls: Vec::new(),
                        },
                        Role::System => deepseek::RequestMessage::System { content },
                    });
                    acc
                });

        deepseek::Request {
            model,
            messages: merged_messages,
            stream: true,
            max_tokens: max_output_tokens,
            // deepseek-reasoner doesn't accept a temperature parameter.
            temperature: if is_reasoner { None } else { self.temperature },
            response_format: None,
            tools: self
                .tools
                .into_iter()
                .map(|tool| deepseek::ToolDefinition::Function {
                    function: deepseek::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        }
    }
}
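
// Hedged usage sketches for the conversion methods above, pinned down as
// tests. The model names below ("o1-mini", "claude-3-5-sonnet-latest") are
// placeholder identifiers for the tests, not recommended model IDs.
#[cfg(test)]
mod request_conversion_examples {
    use super::*;

    fn message(role: Role, text: &str) -> LanguageModelRequestMessage {
        LanguageModelRequestMessage {
            role,
            content: vec![text.into()],
            cache: false,
        }
    }

    fn request(messages: Vec<LanguageModelRequestMessage>) -> LanguageModelRequest {
        LanguageModelRequest {
            messages,
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        }
    }

    // `into_open_ai` disables streaming for o1-family models.
    #[test]
    fn o1_models_do_not_stream() {
        let open_ai_request =
            request(vec![message(Role::User, "hi")]).into_open_ai("o1-mini".to_string(), None);
        assert!(!open_ai_request.stream);
    }

    // `into_anthropic` lifts system messages out of the message list and into
    // the request's dedicated `system` field.
    #[test]
    fn system_messages_become_the_anthropic_system_prompt() {
        let anthropic_request = request(vec![
            message(Role::System, "You are terse."),
            message(Role::User, "hi"),
        ])
        .into_anthropic("claude-3-5-sonnet-latest".to_string(), 1.0, 4096);

        assert_eq!(anthropic_request.system.as_deref(), Some("You are terse."));
        assert_eq!(anthropic_request.messages.len(), 1);
    }
}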

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
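
// A hedged serde sketch: `LanguageModelResponseMessage` derives `Serialize`
// and `Deserialize`, so it round-trips through JSON. The round-trip below
// deliberately avoids assuming how `Role` variants are spelled on the wire.
#[cfg(test)]
mod response_message_example {
    use super::*;

    #[test]
    fn round_trips_through_json() {
        let original = LanguageModelResponseMessage {
            role: Some(Role::Assistant),
            content: Some("hi".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        let parsed: LanguageModelResponseMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, original);
    }
}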