use std::io::{Cursor, Write};

use crate::role::Role;
use base64::write::EncoderWriter;
use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

// Anthropic requires uploaded images to be no larger than this in either dimension.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.0;
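// e.g. with ObjectFit::ScaleDown below, a hypothetical 3136x1568 image is
// resized to 1568x784: the longer edge is clamped to the limit and the
// aspect ratio is preserved.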

impl LanguageModelImage {
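    /// Converts a gpui [`Image`] into an LLM-ready, base64-encoded PNG on a
    /// background thread, downscaling it first if either dimension exceeds
    /// Anthropic's size limit. Returns `None` for unsupported or
    /// undecodable input.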
    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            let format = match data.format() {
                gpui::ImageFormat::Png => image::ImageFormat::Png,
                gpui::ImageFormat::Jpeg => image::ImageFormat::Jpeg,
                gpui::ImageFormat::Webp => image::ImageFormat::WebP,
                gpui::ImageFormat::Gif => image::ImageFormat::Gif,
                _ => return None,
            };

            // Decode using the format we actually received instead of
            // unconditionally treating the bytes as PNG.
            let image = image::load_from_memory_with_format(data.bytes(), format).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let mut base64_image = Vec::new();

            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );

                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let image = image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                } else if format == image::ImageFormat::Png {
                    // Already a PNG; pass the original bytes through untouched.
                    base64_encoder.write_all(data.bytes()).log_err()?;
                } else {
                    // `source` is documented as PNG, so re-encode the other
                    // supported formats before base64-encoding them.
                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;
                    base64_encoder.write_all(png.as_slice()).log_err()?;
                }
            }

            // SAFETY: base64 output is pure ASCII, so it is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves the image into an LLM-ready format (a base64-encoded PNG).
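    ///
    /// A minimal usage sketch (`render_image` is a hypothetical `RenderImage`):
    ///
    /// ```ignore
    /// if let Some(image) = LanguageModelImage::from_render_image(&render_image) {
    ///     let approximate_tokens = image.estimate_tokens();
    /// }
    /// ```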
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
        // Convert from BGRA to RGBA.
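        // Each 4-byte pixel goes from [B, G, R, A] to [R, G, B, A].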
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        // `RgbaImage::from_vec` returns `None` when the buffer length doesn't
        // match the dimensions, so bail out rather than panicking on
        // malformed input.
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: base64 output is pure ASCII, so it is always valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

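    /// Estimates how many tokens this image will consume using Anthropic's
    /// `width * height / 750` heuristic; e.g. a 1500x1500 image costs
    /// (1500 * 1500) / 750 = 3000 tokens.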
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that Anthropic's API attaches a lot of conditions to this formula,
        // and OpenAI prices images differently, so this is only a rough estimate.
        (width * height) / 750
    }
}

#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
}

impl std::fmt::Debug for MessageContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
            MessageContent::Image(i) => f
                .debug_struct("MessageContent")
                .field("image", &i.source.len())
                .finish(),
        }
    }
}

impl MessageContent {
    pub fn as_string(&self) -> &str {
        match self {
            MessageContent::Text(s) => s.as_str(),
            MessageContent::Image(_) => "",
        }
    }
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(s) => Some(s),
            MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str())
        }
        string_buffer
    }

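    /// Returns true when the message has no content, or when its first item
    /// is blank text or an image; only the first item is inspected.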
    pub fn contents_empty(&self) -> bool {
        self.content.is_empty()
            || self
                .content
                .first()
                .map(|content| match content {
                    MessageContent::Text(s) => s.trim().is_empty(),
                    MessageContent::Image(_) => true,
                })
                .unwrap_or(false)
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
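    /// Converts this request into an OpenAI-style request. Only the text
    /// parts of each message are forwarded; image content is dropped.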
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

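    /// Converts this request into a Google AI request. As with the OpenAI
    /// conversion, only the text parts of each message are forwarded.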
    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(t) if !t.is_empty() => {
                                Some(anthropic::RequestContent::Text {
                                    text: t,
                                    cache_control,
                                })
                            }
                            MessageContent::Image(i) => Some(anthropic::RequestContent::Image {
                                source: anthropic::ImageSource {
                                    source_type: "base64".to_string(),
                                    media_type: "image/png".to_string(),
                                    data: i.source.to_string(),
                                },
                                cache_control,
                            }),
                            _ => None,
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
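                    // Anthropic requires user and assistant messages to
                    // alternate, so merge consecutive messages that share
                    // a role.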
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: None,
            top_k: None,
            top_p: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}