use std::io::{Cursor, Write};

use crate::role::Role;
use crate::LanguageModelToolUse;
use base64::write::EncoderWriter;
use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be no larger than this in either dimension.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
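    /// Converts a gpui [`Image`] into an LLM-ready, base64-encoded PNG on a
    /// background thread, scaling it down if either dimension exceeds
    /// `ANTHROPIC_SIZE_LIMIT`. Returns `None` for unsupported formats or if
    /// decoding/encoding fails.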
    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            let format = match data.format() {
                gpui::ImageFormat::Png => image::ImageFormat::Png,
                gpui::ImageFormat::Jpeg => image::ImageFormat::Jpeg,
                gpui::ImageFormat::Webp => image::ImageFormat::WebP,
                gpui::ImageFormat::Gif => image::ImageFormat::Gif,
                _ => return None,
            };

            // Decode using the format we actually received rather than
            // assuming PNG, so JPEG/WebP/GIF inputs decode correctly.
            let image =
                image::load_from_memory_with_format(data.bytes(), format).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let mut base64_image = Vec::new();

            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );

                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let image = image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                } else if format == image::ImageFormat::Png {
                    // Already a PNG within the size limit; pass the bytes through.
                    base64_encoder.write_all(data.bytes()).log_err()?;
                } else {
                    // Re-encode to PNG so `source` always holds what the
                    // struct documents: a base64-encoded PNG.
                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;

                    base64_encoder.write_all(png.as_slice()).log_err()?;
                }
            }

            // SAFETY: The base64 encoder only emits ASCII, which is valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Converts a rendered frame (BGRA) into an LLM-ready format: a
    /// base64-encoded PNG, scaled down if it exceeds Anthropic's size limit.
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0)?.to_vec();
        // Convert from BGRA to RGBA.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: The base64 encoder only emits ASCII, which is valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }
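
    /// Rough estimate of the tokens this image will consume, using
    /// Anthropic's published formula of `(width * height) / 750`.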
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and
        // OpenAI doesn't use this formula, so this is only a rough estimate.
        (width * height) / 750
    }
}
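
/// The result of running a tool, correlated with its originating
/// [`LanguageModelToolUse`] by `tool_use_id`.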
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: String,
    pub is_error: bool,
    pub content: String,
}
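
/// A single chunk of content within a message: plain text, an image, a tool
/// use, or a tool result.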
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
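    /// Concatenates the text and tool-result content of this message into a
    /// single string, skipping image and tool-use chunks.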
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text),
            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            string_buffer.push_str(string.as_str())
        }
        string_buffer
    }
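
    /// Whether this message is effectively empty. Note that only the first
    /// content chunk is inspected: a message whose first chunk is an image or
    /// a tool use counts as empty.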
    pub fn contents_empty(&self) -> bool {
        self.content.is_empty()
            || self
                .content
                .first()
                .map(|content| match content {
                    MessageContent::Text(text) => text.trim().is_empty(),
                    MessageContent::ToolResult(tool_result) => {
                        tool_result.content.trim().is_empty()
                    }
                    MessageContent::ToolUse(_) | MessageContent::Image(_) => true,
                })
                .unwrap_or(false)
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
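    /// Converts this request into OpenAI's chat completion format. Images and
    /// tool calls are dropped: each message is flattened to its string
    /// contents. Streaming is disabled for `o1-` models.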
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        let stream = !model.starts_with("o1-");
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }
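
    /// Converts this request into Google AI's content format. System messages
    /// are sent with the user role, since the API has no system role.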
    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }
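
    /// Converts this request into Anthropic's message format. System messages
    /// are collected into the request's top-level `system` field, empty
    /// messages are skipped, and consecutive messages with the same role are
    /// merged into a single message.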
    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolUse(tool_use) => {
                                Some(anthropic::RequestContent::ToolUse {
                                    id: tool_use.id,
                                    name: tool_use.name,
                                    input: tool_use.input,
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: Some(self.temperature),
            top_k: None,
            top_p: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
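
// A minimal usage sketch for the types above, written as tests. The expected
// values follow directly from the code: `estimate_tokens` is integer division
// by 750, and `string_contents` keeps only text and tool-result chunks.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn estimate_tokens_uses_anthropic_formula() {
        let image = LanguageModelImage {
            source: SharedString::from(""),
            size: size(DevicePixels(1092), DevicePixels(1092)),
        };
        // 1092 * 1092 = 1_192_464, and 1_192_464 / 750 = 1589 (integer division).
        assert_eq!(image.estimate_tokens(), 1589);
    }

    #[test]
    fn string_contents_keeps_text_and_tool_results() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "hello ".into(),
                MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: "tool_1".to_string(),
                    is_error: false,
                    content: "world".to_string(),
                }),
            ],
            cache: false,
        };
        assert_eq!(message.string_contents(), "hello world");
        assert!(!message.contents_empty());
    }
}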