use std::io::{Cursor, Write};

use crate::role::Role;
use base64::write::EncoderWriter;
use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

/// Anthropic requires uploaded images to be at most this large in either dimension.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.0;
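
// For orientation: `ObjectFit::ScaleDown` (used below) fits oversized images
// inside a 1568x1568 box while preserving aspect ratio, so a hypothetical
// 3000x2000 source would come out around 1568x1045, while images already
// within the limit keep their original resolution.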
impl LanguageModelImage {
    /// Converts a `gpui::Image` into a base64-encoded PNG on a background
    /// thread, downscaling it first if either dimension exceeds
    /// `ANTHROPIC_SIZE_LIMIT`. Returns `None` for unsupported formats or if
    /// decoding/encoding fails.
    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            match data.format() {
                gpui::ImageFormat::Png
                | gpui::ImageFormat::Jpeg
                | gpui::ImageFormat::Webp
                | gpui::ImageFormat::Gif => {}
                _ => return None,
            };

            // Decode according to the actual format; a PNG decoder alone
            // would fail on the JPEG/WebP/GIF inputs accepted above.
            let image = image::load_from_memory(data.bytes()).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let mut base64_image = Vec::new();

            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );

                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    // Scale down to fit within the limit, preserving the
                    // aspect ratio, then re-encode as PNG.
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let image = image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                } else if matches!(data.format(), gpui::ImageFormat::Png) {
                    // Already a PNG within the size limit, so pass the
                    // original bytes through untouched.
                    base64_encoder.write_all(data.bytes()).log_err()?;
                } else {
                    // Re-encode JPEG/WebP/GIF inputs so that `source` always
                    // holds a PNG, matching the `image/png` media type used
                    // when building Anthropic requests.
                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                }
            }

            // SAFETY: base64 output is always ASCII, and ASCII is valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves a `RenderImage` into an LLM-ready format (a base64-encoded
    /// PNG), downscaling it if either dimension exceeds `ANTHROPIC_SIZE_LIMIT`.
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
        // Convert from BGRA to RGBA.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        // `from_vec` returns `None` when the buffer doesn't match the stated
        // dimensions (e.g. an empty frame), so bail out rather than panic.
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: base64 output is always ASCII, and ASCII is valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

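    /// Estimates how many tokens this image will consume, using the
    /// Anthropic heuristic documented below: `tokens = (width * height) / 750`.
    /// For example, a 1092x1092 image works out to about 1,590 tokens.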
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that Anthropic's API attaches a lot of conditions to this, and
        // OpenAI doesn't use this formula at all, so treat it as a rough guess.
        (width * height) / 750
    }
}

#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
}

impl std::fmt::Debug for MessageContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
            MessageContent::Image(i) => f
                .debug_struct("MessageContent")
                .field("image", &i.source.len())
                .finish(),
        }
    }
}

impl MessageContent {
    pub fn as_string(&self) -> &str {
        match self {
            MessageContent::Text(s) => s.as_str(),
            MessageContent::Image(_) => "",
        }
    }
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}


#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    /// Concatenates all text content, ignoring images.
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(s) => Some(s),
            MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str())
        }
        string_buffer
    }

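    /// Note: this only inspects the first content item, so a message whose
    /// first element is an image (or whitespace-only text) is considered
    /// empty even if later elements carry text. Illustratively:
    ///
    /// - `[Image, Text("hi")]` -> `contents_empty() == true`
    /// - `[Text("hi"), Image]` -> `contents_empty() == false`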
    pub fn contents_empty(&self) -> bool {
        self.content.is_empty()
            || self
                .content
                .first()
                .map(|content| match content {
                    MessageContent::Text(s) => s.trim().is_empty(),
                    MessageContent::Image(_) => true,
                })
                .unwrap_or(false)
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
    /// Converts this request into an OpenAI request. Image content is
    /// dropped, since `string_contents` only collects text.
    pub fn into_open_ai(self, model: String) -> open_ai::Request {
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: None,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::Content> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(t) if !t.is_empty() => {
                                Some(anthropic::Content::Text {
                                    text: t,
                                    cache_control,
                                })
                            }
                            MessageContent::Image(i) => Some(anthropic::Content::Image {
                                source: anthropic::ImageSource {
                                    source_type: "base64".to_string(),
                                    media_type: "image/png".to_string(),
                                    data: i.source.to_string(),
                                },
                                cache_control,
                            }),
                            _ => None,
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
                    // Anthropic requires roles to alternate, so merge
                    // consecutive messages that share a role.
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    // Anthropic takes system prompts as a top-level request
                    // field rather than as messages, so collect them here.
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: Vec::new(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: None,
            top_k: None,
            top_p: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
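
// A minimal sketch of `into_anthropic`'s system-prompt hoisting, assuming the
// `anthropic` crate's request fields are public (as the construction above
// relies on). The model name is illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn into_anthropic_hoists_system_messages() {
        let request = LanguageModelRequest {
            messages: vec![
                LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec!["Be terse.".into()],
                    cache: false,
                },
                LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec!["Hello!".into()],
                    cache: false,
                },
            ],
            stop: Vec::new(),
            temperature: 1.0,
        };

        let anthropic_request = request.into_anthropic("claude-3-5-sonnet".into(), 4096);

        // The system message moves to the top-level `system` field...
        assert_eq!(anthropic_request.system.as_deref(), Some("Be terse."));
        // ...leaving only the user message in the message list.
        assert_eq!(anthropic_request.messages.len(), 1);
    }
}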