open_ai.rs

use anyhow::{Result, anyhow};
use collections::HashMap;
use futures::{FutureExt, Stream, future::BoxFuture};
use gpui::{App, AppContext as _};
use language_model::{
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    Role, StopReason, TokenUsage,
};
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent};
pub use settings::OpenAiAvailableModel as AvailableModel;
use std::pin::Pin;
use std::str::FromStr;

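/// Settings for the OpenAI language model provider: the API endpoint to use
/// and any additional models the user has made available.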
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

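/// Converts a provider-agnostic [`LanguageModelRequest`] into an
/// [`open_ai::Request`] for the Chat Completions API, flattening each
/// message's content parts into OpenAI's message and tool-call shapes.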
pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    supports_prompt_cache_key: bool,
    max_output_tokens: Option<u64>,
    reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
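    // The o1 family initially shipped without streaming support, so fall back
    // to a non-streaming request for those models.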
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    if !text.trim().is_empty() {
                        add_message_content_part(
                            open_ai::MessagePart::Text { text },
                            message.role,
                            &mut messages,
                        );
                    }
                }
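                // Redacted thinking blocks have no Chat Completions
                // equivalent, so they are dropped.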
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_ai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
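                // Tool calls ride along on the preceding assistant message
                // when there is one; otherwise start a new assistant message
                // that carries only the tool call.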
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_ai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_use.name.to_string(),
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_ai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_ai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_ai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(open_ai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_ai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        temperature: request.temperature.unwrap_or(1.0),
        max_completion_tokens: max_output_tokens,
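        // `Some(false)` explicitly disables parallel tool calls. The field is
        // omitted entirely when the model doesn't support it or no tools are
        // in play.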
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
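        // Reuse the thread id as the prompt cache key so requests from the
        // same thread can hit OpenAI's prompt cache.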
        prompt_cache_key: if supports_prompt_cache_key {
            request.thread_id
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_ai::ToolDefinition::Function {
                function: open_ai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
        }),
        reasoning_effort,
    }
}

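/// Appends `new_part` to the trailing message if it has the same role;
/// otherwise pushes a fresh message. This keeps consecutive parts from one
/// speaker merged into a single OpenAI message.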
fn add_message_content_part(
    new_part: open_ai::MessagePart,
    role: Role,
    messages: &mut Vec<open_ai::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(open_ai::RequestMessage::User { content }))
        | (
            Role::Assistant,
            Some(open_ai::RequestMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => open_ai::RequestMessage::System {
                    content: open_ai::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}

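/// Maps streamed [`ResponseStreamEvent`]s to [`LanguageModelCompletionEvent`]s.
///
/// OpenAI streams tool calls as fragments keyed by index (ids, names, and
/// argument deltas arrive separately), so partial calls are buffered in
/// `tool_calls_by_index` until a finish reason flushes them.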
pub struct OpenAiEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl OpenAiEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

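    /// Adapts a raw stream of response chunks into completion events,
    /// expanding each chunk into zero or more events via [`Self::map_event`].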
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        use futures::StreamExt;
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

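    /// Maps a single streamed chunk to completion events: usage updates, text
    /// deltas, buffered tool-call fragments, and a stop event once the first
    /// choice reports a finish reason.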
    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
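        // The streamed usage payload has no cache token breakdown here, so
        // both cache counters are reported as zero.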
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(delta) = choice.delta.as_ref() {
            if let Some(content) = delta.content.clone() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }

            if let Some(tool_calls) = delta.tool_calls.as_ref() {
                for tool_call in tool_calls {
                    let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                    if let Some(tool_id) = tool_call.id.clone() {
                        entry.id = tool_id;
                    }

                    if let Some(function) = tool_call.function.as_ref() {
                        if let Some(name) = function.name.clone() {
                            entry.name = name;
                        }

                        if let Some(arguments) = function.arguments.clone() {
                            entry.arguments.push_str(&arguments);
                        }
                    }
                }
            }
        }

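        // A finish reason ends the turn: "tool_calls" flushes the buffered
        // fragments (parsing each accumulated argument string as JSON), while
        // "stop" and any unrecognized reason map to an end-of-turn stop event.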
        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                                thought_signature: None,
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}");
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}

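/// A partially assembled tool call, built up from streamed fragments until a
/// finish reason indicates the call is complete.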
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

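/// Flattens request messages into the shape `tiktoken_rs` expects, collapsing
/// each message to its plain string contents.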
pub(crate) fn collect_tiktoken_messages(
    request: LanguageModelRequest,
) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
    request
        .messages
        .into_iter()
        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
            role: match message.role {
                Role::User => "user".into(),
                Role::Assistant => "assistant".into(),
                Role::System => "system".into(),
            },
            content: Some(message.string_contents()),
            name: None,
            function_call: None,
        })
        .collect::<Vec<_>>()
}

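/// Counts the tokens `request` would consume for `model`, on the background
/// executor. Custom models are approximated with a stock tokenizer since
/// `tiktoken_rs` only recognizes built-in model ids.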
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = collect_tiktoken_messages(request);
        match model {
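            // tiktoken-rs can't know user-defined model ids, so pick a stock
            // tokenizer by context window size: large windows suggest a
            // gpt-4o-class model, smaller ones a gpt-4-class model.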
            Model::Custom { max_tokens, .. } => {
                let model = if max_tokens >= 100_000 {
                    "gpt-4o"
                } else {
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini
            | Model::Five
            | Model::FiveMini
            | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
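            // tiktoken-rs may not recognize the gpt-5.1 id yet, so count with
            // the gpt-5 tokenizer instead.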
            Model::FivePointOne => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}

#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;
    use strum::IntoEnumIterator;

    use super::*;

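    /// Every model variant must be countable by `tiktoken_rs`; a failure here
    /// usually means a newly added model needs a tokenizer mapping in
    /// `count_open_ai_tokens`.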
    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("message".into())],
                cache: false,
                reasoning_details: None,
            }],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        for model in Model::iter() {
            let count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(count > 0);
        }
    }
}