use anyhow::{Result, anyhow};
use collections::HashMap;
use futures::{FutureExt, Stream, future::BoxFuture};
use gpui::{App, AppContext as _};
use language_model::{
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse,
    MessageContent, Role, StopReason, TokenUsage,
};
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent};
pub use settings::OpenAiAvailableModel as AvailableModel;
use std::pin::Pin;
use std::str::FromStr;

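/// Settings for the OpenAI language model provider: the API endpoint to use
/// and any additional models configured by the user.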
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

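/// Converts a provider-agnostic [`LanguageModelRequest`] into an
/// [`open_ai::Request`] for the Chat Completions API: text and images are
/// coalesced into per-role messages, tool uses become assistant `tool_calls`,
/// and tool results become `tool` messages.
///
/// A sketch of a call site (the argument values are hypothetical, not
/// recommendations):
///
/// ```ignore
/// let request = into_open_ai(
///     language_model_request,
///     "gpt-4o",
///     /* supports_parallel_tool_calls */ true,
///     /* supports_prompt_cache_key */ true,
///     /* max_output_tokens */ Some(4096),
///     /* reasoning_effort */ None,
/// );
/// ```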
pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    supports_prompt_cache_key: bool,
    max_output_tokens: Option<u64>,
    reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
    // o1-* models do not support streaming, so request a buffered response for them.
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    if !text.trim().is_empty() {
                        add_message_content_part(
                            open_ai::MessagePart::Text { text },
                            message.role,
                            &mut messages,
                        );
                    }
                }
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_ai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_ai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_use.name.to_string(),
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    // Attach the tool call to the trailing assistant message if
                    // there is one, so consecutive tool uses share a message.
                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_ai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_ai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_ai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(open_ai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_ai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        temperature: request.temperature.unwrap_or(1.0),
        max_completion_tokens: max_output_tokens,
        // When tools are present, explicitly disable parallel tool calls
        // (`Some(false)`); otherwise omit the field entirely.
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        prompt_cache_key: if supports_prompt_cache_key {
            request.thread_id
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_ai::ToolDefinition::Function {
                function: open_ai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
        }),
        reasoning_effort,
    }
}

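/// Appends `new_part` to the trailing message when it has the same role and
/// can hold content; otherwise pushes a fresh message for `role`, so
/// consecutive parts from one speaker share a single request message.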
fn add_message_content_part(
    new_part: open_ai::MessagePart,
    role: Role,
    messages: &mut Vec<open_ai::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        // The trailing message matches this part's role and can hold content:
        // append the part to it.
        (Role::User, Some(open_ai::RequestMessage::User { content }))
        | (
            Role::Assistant,
            Some(open_ai::RequestMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
            content.push_part(new_part);
        }
        // Otherwise, start a new message for the role.
        _ => {
            messages.push(match role {
                Role::User => open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => open_ai::RequestMessage::System {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}

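/// Maps OpenAI response stream events to [`LanguageModelCompletionEvent`]s.
///
/// Tool calls arrive as incremental deltas, so the mapper buffers them by
/// index and only emits them once the model reports a `tool_calls` finish
/// reason.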
pub struct OpenAiEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl OpenAiEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

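    /// Consumes the mapper and adapts a raw event stream into completion
    /// events. A sketch of a call site (`response_stream` is a hypothetical
    /// stream of `Result<ResponseStreamEvent>`):
    ///
    /// ```ignore
    /// let completion_events = OpenAiEventMapper::new().map_stream(response_stream);
    /// ```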
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        use futures::StreamExt;
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

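    /// Maps a single stream event, possibly yielding several completion
    /// events: a usage update, a text delta, buffered tool uses, and/or a
    /// stop event, depending on what the chunk carries.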
    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(delta) = choice.delta.as_ref() {
            if let Some(content) = delta.content.clone() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }

            // Tool call deltas arrive in fragments keyed by index; accumulate
            // the id, name, and argument JSON until the stream finishes.
            if let Some(tool_calls) = delta.tool_calls.as_ref() {
                for tool_call in tool_calls {
                    let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                    if let Some(tool_id) = tool_call.id.clone() {
                        entry.id = tool_id;
                    }

                    if let Some(function) = tool_call.function.as_ref() {
                        if let Some(name) = function.name.clone() {
                            entry.name = name;
                        }

                        if let Some(arguments) = function.arguments.clone() {
                            entry.arguments.push_str(&arguments);
                        }
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                // Flush every buffered tool call, parsing the accumulated
                // argument JSON now that it is complete.
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                                thought_signature: None,
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

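/// A tool call under construction: the id, name, and argument JSON stream in
/// as fragments, which are accumulated here until the call is complete.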
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

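/// Reduces each request message to its role and string contents, the shape
/// tiktoken-rs expects for token counting.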
pub(crate) fn collect_tiktoken_messages(
    request: LanguageModelRequest,
) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
    request
        .messages
        .into_iter()
        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
            role: match message.role {
                Role::User => "user".into(),
                Role::Assistant => "assistant".into(),
                Role::System => "system".into(),
            },
            content: Some(message.string_contents()),
            name: None,
            function_call: None,
        })
        .collect::<Vec<_>>()
}

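/// Counts the tokens in `request` on a background thread using tiktoken-rs,
/// substituting a similar known tokenizer for model ids it does not
/// recognize.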
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = collect_tiktoken_messages(request);
        match model {
            // tiktoken-rs cannot know custom model ids; approximate with the
            // gpt-4o tokenizer for large-context models and gpt-4 otherwise.
            Model::Custom { max_tokens, .. } => {
                let model = if max_tokens >= 100_000 {
                    "gpt-4o"
                } else {
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini
            | Model::Five
            | Model::FiveMini
            | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
            // tiktoken-rs does not recognize this id; count with the gpt-5
            // tokenizer as the closest stand-in.
            Model::FivePointOne => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}

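// Exercises every `Model` variant to ensure each one maps to a tokenizer that
// tiktoken-rs can count with.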
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;
    use strum::IntoEnumIterator;

    use super::*;

    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("message".into())],
                cache: false,
                reasoning_details: None,
            }],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        for model in Model::iter() {
            let count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(count > 0);
        }
    }
}