use std::io::{Cursor, Write};

use crate::role::Role;
use base64::write::EncoderWriter;
use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic requires uploaded images to be no larger than this in either dimension.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.0;

impl LanguageModelImage {
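    /// Decodes the image on a background thread and converts it into a
    /// base64-encoded PNG, downscaling it first if either dimension exceeds
    /// Anthropic's size limit.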
    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            let format = match data.format() {
                gpui::ImageFormat::Png => image::ImageFormat::Png,
                gpui::ImageFormat::Jpeg => image::ImageFormat::Jpeg,
                gpui::ImageFormat::Webp => image::ImageFormat::WebP,
                gpui::ImageFormat::Gif => image::ImageFormat::Gif,
                _ => return None,
            };

            // Decode using the format reported by GPUI rather than assuming
            // every image is a PNG.
            let image = image::load_from_memory_with_format(data.bytes(), format).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let image = if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
            {
                // Scale down to fit within the limit while preserving the
                // aspect ratio.
                let new_bounds = ObjectFit::ScaleDown.get_bounds(
                    gpui::Bounds {
                        origin: point(px(0.0), px(0.0)),
                        size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                    },
                    image_size,
                );
                image.resize(
                    new_bounds.size.width.0 as u32,
                    new_bounds.size.height.0 as u32,
                    image::imageops::FilterType::Triangle,
                )
            } else {
                image
            };

            // Re-encode as PNG so the payload always matches the `image/png`
            // media type declared when the image is sent downstream.
            let mut png = Vec::new();
            image
                .write_with_encoder(PngEncoder::new(&mut png))
                .log_err()?;

            let mut base64_image = Vec::new();
            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );
                base64_encoder.write_all(png.as_slice()).log_err()?;
            }

            // SAFETY: The standard base64 alphabet is pure ASCII, so the
            // encoder's output is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Converts a rendered image into an LLM-ready base64-encoded PNG.
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        // Frame 0 holds the pixel data; bail out if it is missing instead of
        // proceeding with an empty buffer.
        let mut bytes = data.as_bytes(0)?.to_vec();
        // Convert from BGRA to RGBA.
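        // e.g. each 4-byte pixel [b, g, r, a] becomes [r, g, b, a].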
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        // The buffer length should match the reported dimensions (4 bytes per
        // pixel); return None rather than panicking if it does not.
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();

        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: The standard base64 alphabet is pure ASCII, so the
        // encoder's output is always valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Anthropic's API attaches a number of additional conditions to this
        // formula, and OpenAI prices images differently, so treat the result
        // as a rough estimate rather than an exact count.
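        // For example, a 1092x1092 image works out to
        // (1092 * 1092) / 750 = 1589 tokens (Anthropic's docs round this to
        // ~1590).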
        (width * height) / 750
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: String,
    pub is_error: bool,
    pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text),
            MessageContent::Image(_) => None,
            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
        }) {
            string_buffer.push_str(string.as_str());
        }
        string_buffer
    }

    pub fn contents_empty(&self) -> bool {
        self.content.iter().all(|content| match content {
            MessageContent::Text(text) => text.trim().is_empty(),
            // An image is substantive content even without any text.
            MessageContent::Image(_) => false,
            MessageContent::ToolResult(tool_result) => tool_result.content.trim().is_empty(),
        })
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        open_ai::Request {
            model,
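            // Message content is flattened to plain text on this path: images
            // are dropped, and only text and tool-result text are forwarded.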
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
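            // As in the OpenAI conversion, content is flattened to plain text
            // and images are not forwarded.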
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        // Google AI doesn't have a system role.
                        Role::System => google_ai::Role::User,
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
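                    // Anthropic expects user/assistant turns to alternate, so
                    // merge consecutive same-role messages into one.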
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            // Forward the stop sequences and temperature like the other
            // conversions do, rather than discarding them.
            stop_sequences: self.stop,
            temperature: Some(self.temperature),
            top_k: None,
            top_p: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
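
// A minimal sanity check of the message helpers above. This is a sketch: it
// relies only on `Role::User` from `crate::role` and the `From<&str>`
// conversion defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_contents_concatenates_text_parts() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Hi, ".into(), "there!".into()],
            cache: false,
        };
        assert_eq!(message.string_contents(), "Hi, there!");
        assert!(!message.contents_empty());
    }
}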