openai.rs
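
//! OpenAI chat-completions provider for the Zed extension API: exposes a
//! static model catalog, converts Zed completion requests into OpenAI request
//! bodies, and incrementally decodes the SSE response stream.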

use std::collections::HashMap;
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

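/// Extension state: in-flight completion streams keyed by the id handed back
/// to the host, plus a counter used to mint those ids.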
struct OpenAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

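/// Per-stream bookkeeping: the live HTTP response, undecoded SSE bytes, and
/// tool calls accumulated from deltas until they can be emitted.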
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    /// Raw bytes received so far; kept as bytes (not `String`) so a UTF-8
    /// sequence split across chunk boundaries is never decoded lossily.
    buffer: Vec<u8>,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    /// Stop reason seen in `finish_reason` while tool calls were still
    /// queued; emitted once the queue has drained.
    pending_stop: Option<LlmStopReason>,
}

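/// One tool call assembled from streamed deltas: `id` and `name` arrive once,
/// while the JSON `arguments` string accumulates chunk by chunk.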
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

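/// Static catalog entry: `real_id` is sent to the OpenAI API, while
/// `display_name` doubles as the model id Zed sees.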
struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    is_default: bool,
    is_default_fast: bool,
}

const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4o-mini",
        display_name: "GPT-4o-mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-mini",
        display_name: "GPT-4.1-mini",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-nano",
        display_name: "GPT-4.1-nano",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5",
        display_name: "GPT-5",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5-mini",
        display_name: "GPT-5-mini",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3",
        display_name: "o3",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o4-mini",
        display_name: "o4-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];

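/// Maps a display name (used as the model id in `llm_provider_models`) back
/// to the wire id expected by the API.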
fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

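/// Request body for `POST /v1/chat/completions`; optional fields are skipped
/// during serialization when unset.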
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // NOTE: newer reasoning models (the o1/o3/gpt-5 families) reject
    // `max_tokens` in favor of `max_completion_tokens`; requests to those
    // models may need this field renamed.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    stream: bool,
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
#[serde(tag = "role")]
enum OpenAiMessage {
    #[serde(rename = "system")]
    System { content: String },
    #[serde(rename = "user")]
    User { content: Vec<OpenAiContentPart> },
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAiToolCall>>,
    },
    #[serde(rename = "tool")]
    Tool {
        tool_call_id: String,
        content: String,
    },
}

#[derive(Serialize)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

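/// One parsed chunk of the streaming chat-completions response.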
#[derive(Deserialize, Debug)]
struct OpenAiStreamEvent {
    choices: Vec<OpenAiChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiError {
    error: OpenAiErrorDetail,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiErrorDetail {
    message: String,
}

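/// Converts Zed's provider-agnostic request into an OpenAI chat-completions
/// body. Tool results are hoisted out of user messages into separate
/// `role: "tool"` messages, as the API requires.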
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let mut messages = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let text: String = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => Some(t.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if !text.is_empty() {
                    messages.push(OpenAiMessage::System { content: text });
                }
            }
            LlmMessageRole::User => {
                let parts: Vec<OpenAiContentPart> = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => {
                            Some(OpenAiContentPart::Text { text: t.clone() })
                        }
                        LlmMessageContent::Image(img) => Some(OpenAiContentPart::ImageUrl {
                            image_url: ImageUrl {
                                url: format!("data:image/png;base64,{}", img.source),
                            },
                        }),
                        LlmMessageContent::ToolResult(_) => None,
                        _ => None,
                    })
                    .collect();

                for content in &msg.content {
                    if let LlmMessageContent::ToolResult(result) = content {
                        let content_text = match &result.content {
                            LlmToolResultContent::Text(t) => t.clone(),
                            LlmToolResultContent::Image(_) => "[Image]".to_string(),
                        };
                        messages.push(OpenAiMessage::Tool {
                            tool_call_id: result.tool_use_id.clone(),
                            content: content_text,
                        });
                    }
                }

                if !parts.is_empty() {
                    messages.push(OpenAiMessage::User { content: parts });
                }
            }
            LlmMessageRole::Assistant => {
                let mut content_text: Option<String> = None;
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for c in &msg.content {
                    match c {
                        LlmMessageContent::Text(t) => {
                            // Accumulate rather than overwrite, in case the
                            // assistant turn carries several text parts.
                            match content_text.as_mut() {
                                Some(existing) => existing.push_str(t),
                                None => content_text = Some(t.clone()),
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage::Assistant {
                    content: content_text,
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                });
            }
        }
    }

    let tools: Option<Vec<OpenAiTool>> = if request.tools.is_empty() {
        None
    } else {
        Some(
            request
                .tools
                .iter()
                .map(|t| OpenAiTool {
                    tool_type: "function".to_string(),
                    function: OpenAiFunctionDef {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        parameters: serde_json::from_str(&t.input_schema)
                            .unwrap_or(serde_json::Value::Object(Default::default())),
                    },
                })
                .collect(),
        )
    };

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    Ok(OpenAiRequest {
        model: real_model_id.to_string(),
        messages,
        tools,
        tool_choice,
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        stop: request.stop_sequences.clone(),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

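/// Parses one `data:` SSE line into a stream event. Returns `None` for
/// non-data lines (comments, keep-alives), the terminal `data: [DONE]`
/// sentinel, and payloads that fail to deserialize.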
fn parse_sse_line(line: &str) -> Option<OpenAiStreamEvent> {
    if let Some(data) = line.strip_prefix("data: ") {
        if data == "[DONE]" {
            return None;
        }
        serde_json::from_str(data).ok()
    } else {
        None
    }
}

impl zed::Extension for OpenAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openai".into(),
            name: "OpenAI".into(),
            icon: Some("icons/openai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("openai").is_some()
    }

    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        if llm_get_credential("openai").is_some() {
            Ok(())
        } else {
            Err("No API key configured".to_string())
        }
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use OpenAI, you need an API key. You can create one [here](https://platform.openai.com/api-keys).".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("openai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("openai").ok_or_else(|| {
            "No API key configured. Please add your OpenAI API key in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.openai.com/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: Vec::new(),
                started: false,
                tool_calls: HashMap::new(),
                pending_stop: None,
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.iter().position(|&b| b == b'\n') {
                // Take one complete SSE line off the front of the byte
                // buffer; decoding per line means a multi-byte UTF-8 sequence
                // split across HTTP chunks is never decoded prematurely.
                let line_bytes: Vec<u8> = state.buffer.drain(..=newline_pos).collect();
                let line = String::from_utf8_lossy(&line_bytes).trim().to_string();

                if line.is_empty() {
                    continue;
                }

                if let Some(event) = parse_sse_line(&line) {
                    if let Some(choice) = event.choices.first() {
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }

                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(reason) = &choice.finish_reason {
                            let stop_reason = match reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if reason == "tool_calls" {
                                // Emit queued tool calls one per poll, lowest
                                // index first, and defer the stop event until
                                // the queue drains at end of stream.
                                if let Some(&index) = state.tool_calls.keys().min() {
                                    if let Some(tool_call) = state.tool_calls.remove(&index) {
                                        state.pending_stop = Some(stop_reason);
                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                            id: tool_call.id,
                                            name: tool_call.name,
                                            input: tool_call.arguments,
                                            thought_signature: None,
                                        })));
                                    }
                                }
                            }

                            if let Some(usage) = event.usage {
                                state.pending_stop = Some(stop_reason);
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: usage.prompt_tokens,
                                    output_tokens: usage.completion_tokens,
                                    cache_creation_input_tokens: None,
                                    cache_read_input_tokens: None,
                                })));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }

                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    }

                    if event.choices.is_empty() {
                        if let Some(usage) = event.usage {
                            return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_tokens,
                                output_tokens: usage.completion_tokens,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            })));
                        }
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    state.buffer.extend_from_slice(&chunk);
                }
                Ok(None) => {
                    // Stream exhausted: flush any tool calls that never saw a
                    // `finish_reason`, then the deferred stop event, one
                    // event per poll.
                    if let Some(index) = state.tool_calls.keys().copied().min() {
                        if let Some(tool_call) = state.tool_calls.remove(&index) {
                            return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                id: tool_call.id,
                                name: tool_call.name,
                                input: tool_call.arguments,
                                thought_signature: None,
                            })));
                        }
                    }
                    if let Some(stop) = state.pending_stop.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop)));
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(OpenAiProvider);
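
// A minimal sanity check for the SSE line parser, written against the chunk
// shape OpenAI documents for streaming chat completions. This is a sketch: it
// exercises only `parse_sse_line` and the serde types, not the HTTP layer.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_a_content_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let event = parse_sse_line(line).expect("data line should parse");
        assert_eq!(event.choices[0].delta.content.as_deref(), Some("Hi"));
    }

    #[test]
    fn ignores_non_data_lines_and_the_done_sentinel() {
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
        assert!(parse_sse_line("").is_none());
    }
}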