use std::io::{Cursor, Write};

use crate::role::Role;
use crate::LanguageModelToolUse;
use base64::write::EncoderWriter;
use gpui::{point, size, App, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, DynamicImage, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
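    /// The image's dimensions in device pixels.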
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be no larger than this in either dimension.
/// See: https://docs.anthropic.com/en/docs/build-with-claude/vision
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
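    /// Converts a gpui [`Image`] into an LLM-ready, base64-encoded PNG,
    /// doing the decode/resize/encode work on the background executor.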
    pub fn from_image(data: Image, cx: &mut App) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            // Use the decoder that matches the image's actual format;
            // decoding JPEG/WebP/GIF bytes as PNG would fail.
            let format = match data.format() {
                gpui::ImageFormat::Png => image::ImageFormat::Png,
                gpui::ImageFormat::Jpeg => image::ImageFormat::Jpeg,
                gpui::ImageFormat::Webp => image::ImageFormat::WebP,
                gpui::ImageFormat::Gif => image::ImageFormat::Gif,
                _ => return None,
            };
            let image: DynamicImage =
                image::load_from_memory_with_format(data.bytes(), format).log_err()?;
            let image_size = size(
                DevicePixels(image.width() as i32),
                DevicePixels(image.height() as i32),
            );

            // https://docs.anthropic.com/en/docs/build-with-claude/vision
            let image = if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
            {
                let new_bounds = ObjectFit::ScaleDown.get_bounds(
                    gpui::Bounds {
                        origin: point(px(0.0), px(0.0)),
                        size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                    },
                    image_size,
                );
                image.resize(
                    new_bounds.size.width.0 as u32,
                    new_bounds.size.height.0 as u32,
                    image::imageops::FilterType::Triangle,
                )
            } else {
                image
            };

            // Always re-encode as PNG so the payload matches the `image/png`
            // media type declared when the image is sent to the model.
            let mut png = Vec::new();
            image
                .write_with_encoder(PngEncoder::new(&mut png))
                .log_err()?;

            let mut base64_image = Vec::new();
            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );
                base64_encoder.write_all(png.as_slice()).log_err()?;
            }

            // SAFETY: base64 output is pure ASCII, which is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves the image into an LLM-ready format (a base64-encoded PNG).
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0)?.to_vec();
        // Convert from BGRA to RGBA by swapping the blue and red channels.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        // `from_vec` only fails if the buffer is too small for the
        // dimensions, which can't happen for a valid frame.
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: base64 output is pure ASCII, which is always valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

    /// Estimates how many tokens this image will consume in a request.
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't use
        // this, so this method is more of a rough guess.
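        // For example, a 1092x1092 image works out to (1092 * 1092) / 750 ≈ 1590 tokens.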
        (width * height) / 750
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: String,
    pub is_error: bool,
    pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
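    /// Marks a prompt-caching breakpoint: for Anthropic requests, this
    /// message's content is sent with an ephemeral `cache_control` entry.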
    pub cache: bool,
}

impl LanguageModelRequestMessage {
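    /// Concatenates the text and tool-result parts of this message into a
    /// single string, ignoring images and tool-use entries.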
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text),
            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str())
        }
        string_buffer
    }

    pub fn contents_empty(&self) -> bool {
        // A message counts as empty only if every part is blank text;
        // images and tool interactions are always meaningful content.
        self.content.iter().all(|content| match content {
            MessageContent::Text(text) => text.trim().is_empty(),
            MessageContent::ToolResult(tool_result) => tool_result.content.trim().is_empty(),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => false,
        })
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
}

impl LanguageModelRequest {
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
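        // OpenAI's o1 models don't support streaming responses, so streaming
        // is only requested for other model families.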
        let stream = !model.starts_with("o1-");
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream,
            stop: self.stop,
            temperature: self.temperature.unwrap_or(1.0),
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

    pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
        // No role merging is required here, so this is a straight map over
        // the messages.
        let messages = self
            .messages
            .into_iter()
            .map(|msg| {
                let content = msg.string_contents();
                match msg.role {
                    Role::User => mistral::RequestMessage::User { content },
                    Role::Assistant => mistral::RequestMessage::Assistant {
                        content: Some(content),
                        tool_calls: Vec::new(),
                    },
                    Role::System => mistral::RequestMessage::System { content },
                }
            })
            .collect();

        mistral::Request {
            model,
            messages,
            stream: true,
            max_tokens: max_output_tokens,
            temperature: self.temperature,
            response_format: None,
            tools: self
                .tools
                .into_iter()
                .map(|tool| mistral::ToolDefinition::Function {
                    function: mistral::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        }
    }

    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

    pub fn into_anthropic(
        self,
        model: String,
        default_temperature: f32,
        max_output_tokens: u32,
    ) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolUse(tool_use) => {
                                Some(anthropic::RequestContent::ToolUse {
                                    id: tool_use.id.to_string(),
                                    name: tool_use.name,
                                    input: tool_use.input,
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
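                    // The Anthropic API requires user and assistant turns to
                    // alternate, so adjacent messages with the same role are
                    // merged into a single message.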
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: self.temperature.or(Some(default_temperature)),
            top_k: None,
            top_p: None,
        }
    }

    pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
        let is_reasoner = model == "deepseek-reasoner";

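        // deepseek-reasoner rejects consecutive messages with the same role,
        // so adjacent same-role messages are merged before sending.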
        let len = self.messages.len();
        let merged_messages =
            self.messages
                .into_iter()
                .fold(Vec::with_capacity(len), |mut acc, msg| {
                    let role = msg.role;
                    let content = msg.string_contents();

                    if is_reasoner {
                        if let Some(last_msg) = acc.last_mut() {
                            match (last_msg, role) {
                                (deepseek::RequestMessage::User { content: last }, Role::User) => {
                                    last.push(' ');
                                    last.push_str(&content);
                                    return acc;
                                }

                                (
                                    deepseek::RequestMessage::Assistant {
                                        content: last_content,
                                        ..
                                    },
                                    Role::Assistant,
                                ) => {
                                    *last_content = last_content
                                        .take()
                                        .map(|c| {
                                            let mut s =
                                                String::with_capacity(c.len() + content.len() + 1);
                                            s.push_str(&c);
                                            s.push(' ');
                                            s.push_str(&content);
                                            s
                                        })
                                        .or(Some(content));

                                    return acc;
                                }
                                _ => {}
                            }
                        }
                    }

                    acc.push(match role {
                        Role::User => deepseek::RequestMessage::User { content },
                        Role::Assistant => deepseek::RequestMessage::Assistant {
                            content: Some(content),
                            tool_calls: Vec::new(),
                        },
                        Role::System => deepseek::RequestMessage::System { content },
                    });
                    acc
                });

        deepseek::Request {
            model,
            messages: merged_messages,
            stream: true,
            max_tokens: max_output_tokens,
            temperature: if is_reasoner { None } else { self.temperature },
            response_format: None,
            tools: self
                .tools
                .into_iter()
                .map(|tool| deepseek::ToolDefinition::Function {
                    function: deepseek::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}