// request.rs

  1use std::io::{Cursor, Write};
  2
  3use crate::role::Role;
  4use crate::LanguageModelToolUse;
  5use base64::write::EncoderWriter;
  6use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
  7use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, ImageDecoder};
  8use serde::{Deserialize, Serialize};
  9use ui::{px, SharedString};
 10use util::ResultExt;
 11
/// An image prepared for upload to a language model: the pixel data encoded
/// as a base64 PNG plus its original pixel dimensions.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Pixel dimensions of the image; used by `estimate_tokens`.
    size: Size<DevicePixels>,
}
 18
 19impl std::fmt::Debug for LanguageModelImage {
 20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 21        f.debug_struct("LanguageModelImage")
 22            .field("source", &format!("<{} bytes>", self.source.len()))
 23            .field("size", &self.size)
 24            .finish()
 25    }
 26}
 27
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
/// See https://docs.anthropic.com/en/docs/build-with-claude/vision
/// NOTE(review): identifier is missing the second "I" of "LIMIT"; renaming it
/// would touch every use site, so the misspelling is kept here.
const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
 30
 31impl LanguageModelImage {
 32    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
 33        cx.background_executor().spawn(async move {
 34            match data.format() {
 35                gpui::ImageFormat::Png
 36                | gpui::ImageFormat::Jpeg
 37                | gpui::ImageFormat::Webp
 38                | gpui::ImageFormat::Gif => {}
 39                _ => return None,
 40            };
 41
 42            let image = image::codecs::png::PngDecoder::new(Cursor::new(data.bytes())).log_err()?;
 43            let (width, height) = image.dimensions();
 44            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
 45
 46            let mut base64_image = Vec::new();
 47
 48            {
 49                let mut base64_encoder = EncoderWriter::new(
 50                    Cursor::new(&mut base64_image),
 51                    &base64::engine::general_purpose::STANDARD,
 52                );
 53
 54                if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
 55                    || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
 56                {
 57                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
 58                        gpui::Bounds {
 59                            origin: point(px(0.0), px(0.0)),
 60                            size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
 61                        },
 62                        image_size,
 63                    );
 64                    let image = DynamicImage::from_decoder(image).log_err()?.resize(
 65                        new_bounds.size.width.0 as u32,
 66                        new_bounds.size.height.0 as u32,
 67                        image::imageops::FilterType::Triangle,
 68                    );
 69
 70                    let mut png = Vec::new();
 71                    image
 72                        .write_with_encoder(PngEncoder::new(&mut png))
 73                        .log_err()?;
 74
 75                    base64_encoder.write_all(png.as_slice()).log_err()?;
 76                } else {
 77                    base64_encoder.write_all(data.bytes()).log_err()?;
 78                }
 79            }
 80
 81            // SAFETY: The base64 encoder should not produce non-UTF8.
 82            let source = unsafe { String::from_utf8_unchecked(base64_image) };
 83
 84            Some(LanguageModelImage {
 85                size: image_size,
 86                source: source.into(),
 87            })
 88        })
 89    }
 90
 91    /// Resolves image into an LLM-ready format (base64).
 92    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
 93        let image_size = data.size(0);
 94
 95        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
 96        // Convert from BGRA to RGBA.
 97        for pixel in bytes.chunks_exact_mut(4) {
 98            pixel.swap(2, 0);
 99        }
100        let mut image = image::RgbaImage::from_vec(
101            image_size.width.0 as u32,
102            image_size.height.0 as u32,
103            bytes,
104        )
105        .expect("We already know this works");
106
107        // https://docs.anthropic.com/en/docs/build-with-claude/vision
108        if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
109            || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
110        {
111            let new_bounds = ObjectFit::ScaleDown.get_bounds(
112                gpui::Bounds {
113                    origin: point(px(0.0), px(0.0)),
114                    size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
115                },
116                image_size,
117            );
118
119            image = resize(
120                &image,
121                new_bounds.size.width.0 as u32,
122                new_bounds.size.height.0 as u32,
123                image::imageops::FilterType::Triangle,
124            );
125        }
126
127        let mut png = Vec::new();
128
129        image
130            .write_with_encoder(PngEncoder::new(&mut png))
131            .log_err()?;
132
133        let mut base64_image = Vec::new();
134
135        {
136            let mut base64_encoder = EncoderWriter::new(
137                Cursor::new(&mut base64_image),
138                &base64::engine::general_purpose::STANDARD,
139            );
140
141            base64_encoder.write_all(png.as_slice()).log_err()?;
142        }
143
144        // SAFETY: The base64 encoder should not produce non-UTF8.
145        let source = unsafe { String::from_utf8_unchecked(base64_image) };
146
147        Some(LanguageModelImage {
148            size: image_size,
149            source: source.into(),
150        })
151    }
152
153    pub fn estimate_tokens(&self) -> usize {
154        let width = self.size.width.0.unsigned_abs() as usize;
155        let height = self.size.height.0.unsigned_abs() as usize;
156
157        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
158        // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
159        // so this method is more of a rough guess.
160        (width * height) / 750
161    }
162}
163
/// The outcome of a tool invocation, keyed back to the originating
/// [`LanguageModelToolUse`] via `tool_use_id`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifier of the tool-use request this result answers.
    pub tool_use_id: String,
    /// True when the tool call failed; `content` then describes the error.
    pub is_error: bool,
    /// Tool output (or error description) as plain text.
    pub content: String,
}
170
/// One piece of a message's content; a message may interleave several kinds.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// An image as a base64-encoded PNG (see [`LanguageModelImage`]).
    Image(LanguageModelImage),
    /// A model-initiated tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The result of a previously issued tool invocation.
    ToolResult(LanguageModelToolResult),
}
178
179impl From<String> for MessageContent {
180    fn from(value: String) -> Self {
181        MessageContent::Text(value)
182    }
183}
184
185impl From<&str> for MessageContent {
186    fn from(value: &str) -> Self {
187        MessageContent::Text(value.to_string())
188    }
189}
190
/// A single message within a [`LanguageModelRequest`].
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message (user / assistant / system).
    pub role: Role,
    /// Ordered content parts making up the message body.
    pub content: Vec<MessageContent>,
    /// When true, providers supporting prompt caching (the Anthropic
    /// conversion) mark this message's content with an ephemeral
    /// cache-control breakpoint.
    pub cache: bool,
}
197
198impl LanguageModelRequestMessage {
199    pub fn string_contents(&self) -> String {
200        let mut string_buffer = String::new();
201        for string in self.content.iter().filter_map(|content| match content {
202            MessageContent::Text(text) => Some(text),
203            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
204            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
205        }) {
206            string_buffer.push_str(string.as_str())
207        }
208        string_buffer
209    }
210
211    pub fn contents_empty(&self) -> bool {
212        self.content.is_empty()
213            || self
214                .content
215                .first()
216                .map(|content| match content {
217                    MessageContent::Text(text) => text.trim().is_empty(),
218                    MessageContent::ToolResult(tool_result) => {
219                        tool_result.content.trim().is_empty()
220                    }
221                    MessageContent::ToolUse(_) | MessageContent::Image(_) => true,
222                })
223                .unwrap_or(false)
224    }
225}
226
/// A tool the model may invoke, described to the provider by name,
/// description, and a JSON Schema for its input.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Tool name the model uses to reference it.
    pub name: String,
    /// Natural-language description of what the tool does.
    pub description: String,
    /// JSON Schema describing the tool's expected input object.
    pub input_schema: serde_json::Value,
}
233
/// A provider-agnostic request to a language model, convertible into the
/// OpenAI, Google, and Anthropic wire formats below.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Conversation history, oldest first.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools the model may call.
    /// NOTE(review): only `into_anthropic` forwards these; the OpenAI and
    /// Google conversions drop them.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Stop sequences that end generation.
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: f32,
}
241
impl LanguageModelRequest {
    /// Converts this request into OpenAI's chat-completions wire format.
    ///
    /// Message content is flattened to plain text via `string_contents`, so
    /// images and tool calls are not forwarded. Streaming is always enabled.
    /// NOTE(review): `self.tools` is dropped here (`tools: Vec::new()`) —
    /// confirm whether OpenAI tool calling is intentionally unsupported.
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

    /// Converts this request into Google's `generateContent` wire format.
    ///
    /// Each message becomes a single text part; images and tool calls are
    /// not forwarded, and tools are dropped entirely.
    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

    /// Converts this request into Anthropic's Messages API format.
    ///
    /// System messages are concatenated (separated by blank lines) into the
    /// request-level `system` field; empty messages are skipped; consecutive
    /// user/assistant messages with the same role are merged into one, since
    /// Anthropic requires strictly alternating roles. Text, images, tool
    /// uses, and tool results are all forwarded as structured content.
    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            // Skip messages with nothing to send; see `contents_empty`.
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    // Prompt-caching breakpoint: applied to every content
                    // item of a message flagged with `cache`. `cache_control`
                    // is reused once per item in the closure below, so it
                    // must be a `Copy` type for this to compile.
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                // Anthropic rejects empty text blocks.
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolUse(tool_use) => {
                                Some(anthropic::RequestContent::ToolUse {
                                    id: tool_use.id,
                                    name: tool_use.name,
                                    input: tool_use.input,
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        // System messages are handled by the outer match arm.
                        Role::System => unreachable!("System role should never occur here"),
                    };
                    // Merge into the previous message when roles match, so
                    // the final list strictly alternates user/assistant.
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    // Accumulate all system messages into one string,
                    // separated by blank lines.
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            // NOTE(review): `Some("")` is sent even when no system message
            // exists — confirm the API treats that as "no system prompt".
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            // NOTE(review): `self.temperature` and `self.stop` are not
            // forwarded to Anthropic — confirm this is intentional.
            temperature: None,
            top_k: None,
            top_p: None,
        }
    }
}
408
/// A (possibly partial) message received back from a language model.
/// Either field may be absent in a given payload.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Role of the responder, when the provider reports one.
    pub role: Option<Role>,
    /// Text content of the response, when present.
    pub content: Option<String>,
}