request.rs

use std::io::{Cursor, Write};

use crate::role::Role;
use base64::write::EncoderWriter;
use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

// Anthropic requires both dimensions of an uploaded image to be at most this many pixels.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.0;
impl LanguageModelImage {
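    /// Converts a gpui [`Image`] into a base64-encoded PNG on a background
    /// thread, scaling it down when either dimension exceeds
    /// `ANTHROPIC_SIZE_LIMIT`. Returns `None` for unsupported formats or if
    /// decoding or encoding fails.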
    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            // Only accept the formats Anthropic's vision API supports.
            match data.format() {
                gpui::ImageFormat::Png
                | gpui::ImageFormat::Jpeg
                | gpui::ImageFormat::Webp
                | gpui::ImageFormat::Gif => {}
                _ => return None,
            };

            let image = image::load_from_memory(data.bytes()).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let mut base64_image = Vec::new();

            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );

                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    // Scale the image down so both dimensions fit within the
                    // limit, preserving its aspect ratio, then re-encode as PNG.
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let image = image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                } else if matches!(data.format(), gpui::ImageFormat::Png) {
                    // Already a PNG within the size limit; pass the original bytes through.
                    base64_encoder.write_all(data.bytes()).log_err()?;
                } else {
                    // Re-encode other formats as PNG so `source` is always a PNG.
                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;
                    base64_encoder.write_all(png.as_slice()).log_err()?;
                }
            }

            // SAFETY: Base64 output is always ASCII, and therefore valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves a [`RenderImage`] into an LLM-ready format (a base64-encoded PNG).
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
        // Convert from BGRA to RGBA by swapping the blue and red channels.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        // Returns `None` if the byte buffer doesn't match the stated dimensions.
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: Base64 output is always ASCII, and therefore valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

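    /// Estimates the token cost of this image using Anthropic's published
    /// heuristic of roughly `width * height / 750` tokens. For example, a
    /// 1092 x 1092 px image comes out to about 1590 tokens.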
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Anthropic's API applies several additional conditions, and OpenAI
        // doesn't use this formula at all, so treat this as a rough estimate.
        (width * height) / 750
    }
}

#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
}
impl std::fmt::Debug for MessageContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
            MessageContent::Image(i) => f
                .debug_struct("MessageContent")
                .field("image", &i.source.len())
                .finish(),
        }
    }
}

impl MessageContent {
    pub fn as_string(&self) -> &str {
        match self {
            MessageContent::Text(s) => s.as_str(),
            MessageContent::Image(_) => "",
        }
    }
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
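    /// Concatenates the text content of this message, skipping images.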
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(s) => Some(s),
            MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str());
        }
        string_buffer
    }

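    /// Reports whether this message is effectively empty. Only the first
    /// content item is inspected: whitespace-only text, or an image, in the
    /// first position counts as empty regardless of what follows.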
    pub fn contents_empty(&self) -> bool {
        self.content.is_empty()
            || self
                .content
                .first()
                .map(|content| match content {
                    MessageContent::Text(s) => s.trim().is_empty(),
                    MessageContent::Image(_) => true,
                })
                .unwrap_or(false)
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
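    /// Converts this request into the OpenAI chat completion format. Only the
    /// text content of each message is forwarded; images are dropped.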
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

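    /// Converts this request into the Google AI format. System messages are
    /// sent with the user role, since the API has no system role, and images
    /// are dropped.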
    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

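    /// Converts this request into the Anthropic Messages format, folding
    /// system messages into the top-level `system` field and merging
    /// consecutive messages that share a role.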
    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::Content> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(t) if !t.is_empty() => {
                                Some(anthropic::Content::Text {
                                    text: t,
                                    cache_control,
                                })
                            }
                            MessageContent::Image(i) => Some(anthropic::Content::Image {
                                source: anthropic::ImageSource {
                                    source_type: "base64".to_string(),
                                    media_type: "image/png".to_string(),
                                    data: i.source.to_string(),
                                },
                                cache_control,
                            }),
                            _ => None,
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
                    // Anthropic expects user and assistant turns to alternate,
                    // so fold consecutive same-role messages into one.
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: Vec::new(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: None,
            top_k: None,
            top_p: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
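
// A minimal sketch of unit tests for the pure helpers above, assuming the
// standard Rust test harness; the expected token count follows Anthropic's
// published width * height / 750 heuristic.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_contents_concatenates_text_and_skips_images() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Hello, ".into(), "world!".into()],
            cache: false,
        };
        assert_eq!(message.string_contents(), "Hello, world!");
    }

    #[test]
    fn contents_empty_treats_whitespace_only_text_as_empty() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["   ".into()],
            cache: false,
        };
        assert!(message.contents_empty());
    }

    #[test]
    fn estimate_tokens_follows_the_anthropic_heuristic() {
        let image = LanguageModelImage {
            source: "".into(),
            size: size(DevicePixels(1092), DevicePixels(1092)),
        };
        // 1092 * 1092 / 750 == 1589 with integer division.
        assert_eq!(image.estimate_tokens(), 1589);
    }
}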