request.rs

  1use std::io::{Cursor, Write};
  2
  3use crate::role::Role;
  4use base64::write::EncoderWriter;
  5use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
  6use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, ImageDecoder};
  7use serde::{Deserialize, Serialize};
  8use ui::{px, SharedString};
  9use util::ResultExt;
 10
/// An image attached to a language-model request.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
pub struct LanguageModelImage {
    /// A base64 encoded PNG image.
    pub source: SharedString,
    /// Pixel dimensions recorded when the image was created; used by `estimate_tokens`.
    size: Size<DevicePixels>,
}
 17
 18const ANTHROPIC_SIZE_LIMT: f32 = 1568.0; // Anthropic wants uploaded images to be smaller than this in both dimensions
 19
 20impl LanguageModelImage {
 21    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
 22        cx.background_executor().spawn(async move {
 23            match data.format() {
 24                gpui::ImageFormat::Png
 25                | gpui::ImageFormat::Jpeg
 26                | gpui::ImageFormat::Webp
 27                | gpui::ImageFormat::Gif => {}
 28                _ => return None,
 29            };
 30
 31            let image = image::codecs::png::PngDecoder::new(Cursor::new(data.bytes())).log_err()?;
 32            let (width, height) = image.dimensions();
 33            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
 34
 35            let mut base64_image = Vec::new();
 36
 37            {
 38                let mut base64_encoder = EncoderWriter::new(
 39                    Cursor::new(&mut base64_image),
 40                    &base64::engine::general_purpose::STANDARD,
 41                );
 42
 43                if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
 44                    || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
 45                {
 46                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
 47                        gpui::Bounds {
 48                            origin: point(px(0.0), px(0.0)),
 49                            size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
 50                        },
 51                        image_size,
 52                    );
 53                    let image = DynamicImage::from_decoder(image).log_err()?.resize(
 54                        new_bounds.size.width.0 as u32,
 55                        new_bounds.size.height.0 as u32,
 56                        image::imageops::FilterType::Triangle,
 57                    );
 58
 59                    let mut png = Vec::new();
 60                    image
 61                        .write_with_encoder(PngEncoder::new(&mut png))
 62                        .log_err()?;
 63
 64                    base64_encoder.write_all(png.as_slice()).log_err()?;
 65                } else {
 66                    base64_encoder.write_all(data.bytes()).log_err()?;
 67                }
 68            }
 69
 70            // SAFETY: The base64 encoder should not produce non-UTF8
 71            let source = unsafe { String::from_utf8_unchecked(base64_image) };
 72
 73            Some(LanguageModelImage {
 74                size: image_size,
 75                source: source.into(),
 76            })
 77        })
 78    }
 79
 80    /// Resolves image into an LLM-ready format (base64)
 81    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
 82        let image_size = data.size(0);
 83
 84        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
 85        // Convert from BGRA to RGBA.
 86        for pixel in bytes.chunks_exact_mut(4) {
 87            pixel.swap(2, 0);
 88        }
 89        let mut image = image::RgbaImage::from_vec(
 90            image_size.width.0 as u32,
 91            image_size.height.0 as u32,
 92            bytes,
 93        )
 94        .expect("We already know this works");
 95
 96        // https://docs.anthropic.com/en/docs/build-with-claude/vision
 97        if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
 98            || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
 99        {
100            let new_bounds = ObjectFit::ScaleDown.get_bounds(
101                gpui::Bounds {
102                    origin: point(px(0.0), px(0.0)),
103                    size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
104                },
105                image_size,
106            );
107
108            image = resize(
109                &image,
110                new_bounds.size.width.0 as u32,
111                new_bounds.size.height.0 as u32,
112                image::imageops::FilterType::Triangle,
113            );
114        }
115
116        let mut png = Vec::new();
117
118        image
119            .write_with_encoder(PngEncoder::new(&mut png))
120            .log_err()?;
121
122        let mut base64_image = Vec::new();
123
124        {
125            let mut base64_encoder = EncoderWriter::new(
126                Cursor::new(&mut base64_image),
127                &base64::engine::general_purpose::STANDARD,
128            );
129
130            base64_encoder.write_all(png.as_slice()).log_err()?;
131        }
132
133        // SAFETY: The base64 encoder should not produce non-UTF8
134        let source = unsafe { String::from_utf8_unchecked(base64_image) };
135
136        Some(LanguageModelImage {
137            size: image_size,
138            source: source.into(),
139        })
140    }
141
142    pub fn estimate_tokens(&self) -> usize {
143        let width = self.size.width.0.unsigned_abs() as usize;
144        let height = self.size.height.0.unsigned_abs() as usize;
145
146        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
147        // Note that are a lot of conditions on anthropic's API, and OpenAI doesn't use this,
148        // so this method is more of a rough guess
149        (width * height) / 750
150    }
151}
152
/// One part of a message's content.
#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// An image, carried as base64-encoded PNG data.
    Image(LanguageModelImage),
}
158
159impl std::fmt::Debug for MessageContent {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        match self {
162            MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
163            MessageContent::Image(i) => f
164                .debug_struct("MessageContent")
165                .field("image", &i.source.len())
166                .finish(),
167        }
168    }
169}
170
171impl MessageContent {
172    pub fn as_string(&self) -> &str {
173        match self {
174            MessageContent::Text(s) => s.as_str(),
175            MessageContent::Image(_) => "",
176        }
177    }
178}
179
180impl From<String> for MessageContent {
181    fn from(value: String) -> Self {
182        MessageContent::Text(value)
183    }
184}
185
186impl From<&str> for MessageContent {
187    fn from(value: &str) -> Self {
188        MessageContent::Text(value.to_string())
189    }
190}
191
/// A single message in a request: who sent it and its content parts, in order.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
}
197
198impl LanguageModelRequestMessage {
199    pub fn string_contents(&self) -> String {
200        let mut string_buffer = String::new();
201        for string in self.content.iter().filter_map(|content| match content {
202            MessageContent::Text(s) => Some(s),
203            MessageContent::Image(_) => None,
204        }) {
205            string_buffer.push_str(string.as_str())
206        }
207        string_buffer
208    }
209
210    pub fn contents_empty(&self) -> bool {
211        self.content.is_empty()
212            || self
213                .content
214                .get(0)
215                .map(|content| match content {
216                    MessageContent::Text(s) => s.is_empty(),
217                    MessageContent::Image(_) => true,
218                })
219                .unwrap_or(false)
220    }
221}
222
/// A provider-agnostic completion request, converted for a specific provider by
/// `into_open_ai`, `into_google` or `into_anthropic`.
// NOTE(review): `into_anthropic` currently forwards neither `stop` nor `temperature`.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Messages in conversation order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop sequences.
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: f32,
}
229
230impl LanguageModelRequest {
231    pub fn into_open_ai(self, model: String) -> open_ai::Request {
232        open_ai::Request {
233            model,
234            messages: self
235                .messages
236                .into_iter()
237                .map(|msg| match msg.role {
238                    Role::User => open_ai::RequestMessage::User {
239                        content: msg.string_contents(),
240                    },
241                    Role::Assistant => open_ai::RequestMessage::Assistant {
242                        content: Some(msg.string_contents()),
243                        tool_calls: Vec::new(),
244                    },
245                    Role::System => open_ai::RequestMessage::System {
246                        content: msg.string_contents(),
247                    },
248                })
249                .collect(),
250            stream: true,
251            stop: self.stop,
252            temperature: self.temperature,
253            max_tokens: None,
254            tools: Vec::new(),
255            tool_choice: None,
256        }
257    }
258
259    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
260        google_ai::GenerateContentRequest {
261            model,
262            contents: self
263                .messages
264                .into_iter()
265                .map(|msg| google_ai::Content {
266                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
267                        text: msg.string_contents(),
268                    })],
269                    role: match msg.role {
270                        Role::User => google_ai::Role::User,
271                        Role::Assistant => google_ai::Role::Model,
272                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
273                    },
274                })
275                .collect(),
276            generation_config: Some(google_ai::GenerationConfig {
277                candidate_count: Some(1),
278                stop_sequences: Some(self.stop),
279                max_output_tokens: None,
280                temperature: Some(self.temperature as f64),
281                top_p: None,
282                top_k: None,
283            }),
284            safety_settings: None,
285        }
286    }
287
288    pub fn into_anthropic(self, model: String) -> anthropic::Request {
289        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
290        let mut system_message = String::new();
291
292        for message in self.messages {
293            if message.contents_empty() {
294                continue;
295            }
296
297            match message.role {
298                Role::User | Role::Assistant => {
299                    if let Some(last_message) = new_messages.last_mut() {
300                        if last_message.role == message.role {
301                            // TODO: is this append done properly?
302                            last_message.content.push(MessageContent::Text(format!(
303                                "\n\n{}",
304                                message.string_contents()
305                            )));
306                            continue;
307                        }
308                    }
309
310                    new_messages.push(message);
311                }
312                Role::System => {
313                    if !system_message.is_empty() {
314                        system_message.push_str("\n\n");
315                    }
316                    system_message.push_str(&message.string_contents());
317                }
318            }
319        }
320
321        anthropic::Request {
322            model,
323            messages: new_messages
324                .into_iter()
325                .filter_map(|message| {
326                    Some(anthropic::Message {
327                        role: match message.role {
328                            Role::User => anthropic::Role::User,
329                            Role::Assistant => anthropic::Role::Assistant,
330                            Role::System => return None,
331                        },
332                        content: message
333                            .content
334                            .into_iter()
335                            // TODO: filter out the empty messages in the message construction step
336                            .filter_map(|content| match content {
337                                MessageContent::Text(t) if !t.is_empty() => {
338                                    Some(anthropic::Content::Text { text: t })
339                                }
340                                MessageContent::Image(i) => Some(anthropic::Content::Image {
341                                    source: anthropic::ImageSource {
342                                        source_type: "base64".to_string(),
343                                        media_type: "image/png".to_string(),
344                                        data: i.source.to_string(),
345                                    },
346                                }),
347                                _ => None,
348                            })
349                            .collect(),
350                    })
351                })
352                .collect(),
353            max_tokens: 4092,
354            system: Some(system_message),
355            tools: Vec::new(),
356            tool_choice: None,
357            metadata: None,
358            stop_sequences: Vec::new(),
359            temperature: None,
360            top_k: None,
361            top_p: None,
362        }
363    }
364}
365
/// A message returned by a model.
// NOTE(review): both fields are optional, presumably because streamed chunks may omit
// them — confirm against the callers that build this type.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}