use std::io::{Cursor, Write};

use crate::role::Role;
use base64::write::EncoderWriter;
use gpui::{point, size, AppContext, DevicePixels, Image, ObjectFit, RenderImage, Size, Task};
use image::{codecs::png::PngEncoder, imageops::resize, GenericImageView};
use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
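    /// Converts a gpui [`Image`] into a base64-encoded PNG on a background
    /// thread, scaling it down first if either dimension exceeds
    /// `ANTHROPIC_SIZE_LIMIT`. Returns `None` for unsupported formats or
    /// decode failures.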
    pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
        cx.background_executor().spawn(async move {
            // Decode with the decoder matching the input format. (Running
            // every format through the PNG decoder would make JPEG, WebP,
            // and GIF inputs fail to decode.)
            let format = match data.format() {
                gpui::ImageFormat::Png => image::ImageFormat::Png,
                gpui::ImageFormat::Jpeg => image::ImageFormat::Jpeg,
                gpui::ImageFormat::Webp => image::ImageFormat::WebP,
                gpui::ImageFormat::Gif => image::ImageFormat::Gif,
                _ => return None,
            };
            let image = image::load_from_memory_with_format(data.bytes(), format).log_err()?;
            let (width, height) = image.dimensions();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let mut base64_image = Vec::new();

            {
                let mut base64_encoder = EncoderWriter::new(
                    Cursor::new(&mut base64_image),
                    &base64::engine::general_purpose::STANDARD,
                );

                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let image = image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;
                    base64_encoder.write_all(png.as_slice()).log_err()?;
                } else if format == image::ImageFormat::Png {
                    // Already a PNG within the size limit, so pass the
                    // original bytes through without re-encoding.
                    base64_encoder.write_all(data.bytes()).log_err()?;
                } else {
                    // `source` is documented as a PNG (and is sent to
                    // Anthropic with an `image/png` media type), so
                    // re-encode other formats.
                    let mut png = Vec::new();
                    image
                        .write_with_encoder(PngEncoder::new(&mut png))
                        .log_err()?;
                    base64_encoder.write_all(png.as_slice()).log_err()?;
                }
            }

            // SAFETY: The standard base64 alphabet is ASCII, so the encoder's
            // output is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Resolves a rendered image into an LLM-ready format (base64-encoded PNG).
    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
        let image_size = data.size(0);

        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
        // Convert from BGRA to RGBA.
        for pixel in bytes.chunks_exact_mut(4) {
            pixel.swap(2, 0);
        }
        // Returns `None` (rather than panicking) if the frame was empty or
        // its byte length does not match the reported dimensions.
        let mut image = image::RgbaImage::from_vec(
            image_size.width.0 as u32,
            image_size.height.0 as u32,
            bytes,
        )?;

        // https://docs.anthropic.com/en/docs/build-with-claude/vision
        if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
            || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
        {
            let new_bounds = ObjectFit::ScaleDown.get_bounds(
                gpui::Bounds {
                    origin: point(px(0.0), px(0.0)),
                    size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                },
                image_size,
            );

            image = resize(
                &image,
                new_bounds.size.width.0 as u32,
                new_bounds.size.height.0 as u32,
                image::imageops::FilterType::Triangle,
            );
        }

        let mut png = Vec::new();
        image
            .write_with_encoder(PngEncoder::new(&mut png))
            .log_err()?;

        let mut base64_image = Vec::new();

        {
            let mut base64_encoder = EncoderWriter::new(
                Cursor::new(&mut base64_image),
                &base64::engine::general_purpose::STANDARD,
            );

            base64_encoder.write_all(png.as_slice()).log_err()?;
        }

        // SAFETY: The standard base64 alphabet is ASCII, so the encoder's
        // output is always valid UTF-8.
        let source = unsafe { String::from_utf8_unchecked(base64_image) };

        Some(LanguageModelImage {
            size: image_size,
            source: source.into(),
        })
    }

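    /// Estimates the token cost of this image using Anthropic's published
    /// `(width * height) / 750` approximation.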
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and
        // OpenAI doesn't use this approach, so this method is only a rough
        // guess.
        (width * height) / 750
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: String,
    pub is_error: bool,
    pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Image(LanguageModelImage),
    ToolResult(LanguageModelToolResult),
}

impl MessageContent {
    pub fn as_string(&self) -> &str {
        match self {
            MessageContent::Text(text) => text.as_str(),
            MessageContent::Image(_) => "",
            MessageContent::ToolResult(tool_result) => tool_result.content.as_str(),
        }
    }
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    pub fn string_contents(&self) -> String {
        let mut string_buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text),
            MessageContent::Image(_) => None,
            MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
        }) {
            string_buffer.push_str(string.as_str())
        }
        string_buffer
    }

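    /// Reports whether this message is effectively empty. Only the first
    /// content item is inspected; an image in the first position counts as
    /// empty, since only textual content is considered.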
    pub fn contents_empty(&self) -> bool {
        self.content.is_empty()
            || self
                .content
                .first()
                .map(|content| match content {
                    MessageContent::Text(text) => text.trim().is_empty(),
                    MessageContent::Image(_) => true,
                    MessageContent::ToolResult(tool_result) => {
                        tool_result.content.trim().is_empty()
                    }
                })
                .unwrap_or(false)
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
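    /// Converts this request into an OpenAI-style chat completion request.
    /// Message content is flattened to plain text (images are dropped), and
    /// `self.tools` is not forwarded.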
    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
        open_ai::Request {
            model,
            messages: self
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => open_ai::RequestMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => open_ai::RequestMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: Vec::new(),
                    },
                    Role::System => open_ai::RequestMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: max_output_tokens,
            tools: Vec::new(),
            tool_choice: None,
        }
    }

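    /// Converts this request into a Google AI `generateContent` request. Each
    /// message becomes a single text part; since Google AI has no system
    /// role, system messages are sent with the user role.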
    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
        google_ai::GenerateContentRequest {
            model,
            contents: self
                .messages
                .into_iter()
                .map(|msg| google_ai::Content {
                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                        text: msg.string_contents(),
                    })],
                    role: match msg.role {
                        Role::User => google_ai::Role::User,
                        Role::Assistant => google_ai::Role::Model,
                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role.
                    },
                })
                .collect(),
            generation_config: Some(google_ai::GenerationConfig {
                candidate_count: Some(1),
                stop_sequences: Some(self.stop),
                max_output_tokens: None,
                temperature: Some(self.temperature as f64),
                top_p: None,
                top_k: None,
            }),
            safety_settings: None,
        }
    }

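    /// Converts this request into an Anthropic Messages API request. Empty
    /// messages are skipped, and system messages are concatenated into the
    /// top-level `system` field rather than appearing in `messages`.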
    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
        let mut new_messages: Vec<anthropic::Message> = Vec::new();
        let mut system_message = String::new();

        for message in self.messages {
            if message.contents_empty() {
                continue;
            }

            match message.role {
                Role::User | Role::Assistant => {
                    let cache_control = if message.cache {
                        Some(anthropic::CacheControl {
                            cache_type: anthropic::CacheControlType::Ephemeral,
                        })
                    } else {
                        None
                    };
                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
                        .content
                        .into_iter()
                        .filter_map(|content| match content {
                            MessageContent::Text(text) => {
                                if !text.is_empty() {
                                    Some(anthropic::RequestContent::Text {
                                        text,
                                        cache_control,
                                    })
                                } else {
                                    None
                                }
                            }
                            MessageContent::Image(image) => {
                                Some(anthropic::RequestContent::Image {
                                    source: anthropic::ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                    cache_control,
                                })
                            }
                            MessageContent::ToolResult(tool_result) => {
                                Some(anthropic::RequestContent::ToolResult {
                                    tool_use_id: tool_result.tool_use_id,
                                    is_error: tool_result.is_error,
                                    content: tool_result.content,
                                    cache_control,
                                })
                            }
                        })
                        .collect();
                    let anthropic_role = match message.role {
                        Role::User => anthropic::Role::User,
                        Role::Assistant => anthropic::Role::Assistant,
                        Role::System => unreachable!("System role should never occur here"),
                    };
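                    // Anthropic's Messages API expects user and assistant
                    // turns to alternate, so consecutive messages with the
                    // same role are merged into one.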
                    if let Some(last_message) = new_messages.last_mut() {
                        if last_message.role == anthropic_role {
                            last_message.content.extend(anthropic_message_content);
                            continue;
                        }
                    }
                    new_messages.push(anthropic::Message {
                        role: anthropic_role,
                        content: anthropic_message_content,
                    });
                }
                Role::System => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.string_contents());
                }
            }
        }

        anthropic::Request {
            model,
            messages: new_messages,
            max_tokens: max_output_tokens,
            system: Some(system_message),
            tools: self
                .tools
                .into_iter()
                .map(|tool| anthropic::Tool {
                    name: tool.name,
                    description: tool.description,
                    input_schema: tool.input_schema,
                })
                .collect(),
            tool_choice: None,
            metadata: None,
            stop_sequences: Vec::new(),
            temperature: None,
            top_k: None,
            top_p: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
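
// A minimal test sketch for the pure helpers above; the messages and sizes
// used here are illustrative values, not fixtures from the rest of the
// codebase.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_contents_flattens_text_and_tool_results() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Hello, ".into(),
                MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: "tool_1".to_string(),
                    is_error: false,
                    content: "world".to_string(),
                }),
            ],
            cache: false,
        };
        assert_eq!(message.string_contents(), "Hello, world");
        assert!(!message.contents_empty());
    }

    #[test]
    fn estimate_tokens_follows_anthropic_approximation() {
        let image = LanguageModelImage {
            source: "".into(),
            size: size(DevicePixels(1000), DevicePixels(1000)),
        };
        // (1000 * 1000) / 750 == 1333 with integer division.
        assert_eq!(image.estimate_tokens(), 1333);
    }
}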