use std::io::{Cursor, Write};

use crate::role::Role;
use crate::LanguageModelToolUse;
use base64::write::EncoderWriter;
use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
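    /// The dimensions of the source image, in device pixels.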
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
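    /// Converts a gpui image into a base64-encoded PNG suitable for model
    /// requests, resizing it first if either dimension exceeds the Anthropic
    /// size limit.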
    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            // Decode using the declared format so JPEG, WebP, and GIF uploads
            // are not rejected by the PNG decoder.
            let format = match data.format() {
                gpui::ImageFormat::Png => image::ImageFormat::Png,
                gpui::ImageFormat::Jpeg => image::ImageFormat::Jpeg,
                gpui::ImageFormat::Webp => image::ImageFormat::WebP,
                gpui::ImageFormat::Gif => image::ImageFormat::Gif,
                _ => return None,
            };

            let image = image::load_from_memory_with_format(data.bytes(), format).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let mut base64_image = Vec::new();

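            // Scope the encoder so that dropping it flushes the final base64
            // padding into `base64_image` before the buffer is read below.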
            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );

                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let image = image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                } else if matches!(data.format(), gpui::ImageFormat::Png) {
                    // Already a PNG within the size limit; pass the bytes through.
                    base64_encoder.write_all(data.bytes()).log_err()?;
                } else {
                    // Re-encode small non-PNG images so the payload matches the
                    // `image/png` media type declared when the image is sent.
                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                }
            }

            // SAFETY: The base64 encoder only produces ASCII, which is valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves a render image into an LLM-ready format (a base64-encoded PNG).
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
        // Convert from the BGRA byte order used by `RenderImage` to the RGBA
        // order that `image::RgbaImage` expects.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )
        .expect("buffer length should match the reported image dimensions");

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

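        // Scope the encoder so that dropping it flushes the final base64
        // padding into `base64_image` before the buffer is read below.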
        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: The base64 encoder only produces ASCII, which is valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't use this
        // formula, so this is only a rough estimate.
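        // For example, a 1092 x 1092 px image comes out to (1092 * 1092) / 750,
        // roughly 1,590 tokens.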
        (width * height) / 750
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: String,
    pub is_error: bool,
    pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
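    /// Concatenates all text and tool-result content into a single string,
    /// skipping images and tool uses.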
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text),
            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str())
        }
        string_buffer
    }

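    /// Whether the message is effectively empty. Note that only the first
    /// content entry is inspected.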
    pub fn contents_empty(&self) -> bool {
        self.content.is_empty()
            || self
                .content
                .first()
                .map(|content| match content {
                    MessageContent::Text(text) => text.trim().is_empty(),
                    MessageContent::ToolResult(tool_result) => {
                        tool_result.content.trim().is_empty()
                    }
                    MessageContent::ToolUse(_) | MessageContent::Image(_) => true,
                })
                .unwrap_or(false)
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
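    /// Converts the request into OpenAI's chat format. Message content is
    /// flattened into plain text, so images and tool calls are not forwarded.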
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

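    /// Converts the request into Google AI's format. Each message becomes a
    /// single text part, so images and tool calls are not forwarded.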
    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

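    /// Converts the request into Anthropic's format. System messages are
    /// collected into the top-level `system` prompt, and consecutive messages
    /// with the same role are merged into one.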
    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolUse(tool_use) => {
                                Some(anthropic::RequestContent::ToolUse {
                                    id: tool_use.id,
                                    name: tool_use.name,
                                    input: tool_use.input,
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
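                    // Merge consecutive messages that share a role so the
                    // conversation alternates between user and assistant turns.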
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
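                // Anthropic has no system role; system prompts are passed via
                // the request's top-level `system` field instead.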
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: None,
            top_k: None,
            top_p: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
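
// Minimal sanity checks for the pure helpers above. This is a sketch: it only
// exercises `string_contents`, `contents_empty`, and `estimate_tokens`, and
// relies solely on items already defined or imported in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_contents_concatenates_text_and_skips_images() {
        let image = LanguageModelImage {
            source: "".into(),
            size: size(DevicePixels(1), DevicePixels(1)),
        };
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                MessageContent::Text("Hello, ".to_string()),
                MessageContent::Image(image),
                MessageContent::Text("world!".to_string()),
            ],
            cache: false,
        };

        assert_eq!(message.string_contents(), "Hello, world!");
        assert!(!message.contents_empty());
    }

    #[test]
    fn estimate_tokens_follows_the_documented_formula() {
        let image = LanguageModelImage {
            source: "".into(),
            size: size(DevicePixels(1092), DevicePixels(1092)),
        };

        // (1092 * 1092) / 750 = 1589 with integer division.
        assert_eq!(image.estimate_tokens(), 1589);
    }
}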