openai.rs

use std::collections::HashMap;
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

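/// Zed LLM provider backed by the OpenAI Chat Completions API. Active
/// completion streams live in a map keyed by generated stream IDs so that
/// `llm_stream_completion_next` can be polled repeatedly per stream.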
struct OpenAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

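/// Per-stream state: raw response bytes are buffered until complete SSE
/// lines are available, and streamed tool-call fragments are accumulated
/// by index until a finish chunk (or end of stream) flushes them.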
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    is_default: bool,
    is_default_fast: bool,
}

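/// Static model catalogue. Display names double as the model IDs exposed to
/// Zed, and `get_real_model_id` maps them back to OpenAI's wire names; the
/// token limits mirror OpenAI's published figures at the time of writing.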
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4o-mini",
        display_name: "GPT-4o-mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-mini",
        display_name: "GPT-4.1-mini",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-nano",
        display_name: "GPT-4.1-nano",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5",
        display_name: "GPT-5",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5-mini",
        display_name: "GPT-5-mini",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3",
        display_name: "o3",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o4-mini",
        display_name: "o4-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // `max_tokens` is deprecated in the Chat Completions API and rejected by
    // the o-series reasoning models, so serialize as `max_completion_tokens`.
    #[serde(
        rename = "max_completion_tokens",
        skip_serializing_if = "Option::is_none"
    )]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    stream: bool,
    stream_options: Option<StreamOptions>,
}

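/// With `include_usage` set, OpenAI delivers token usage in a final chunk
/// whose `choices` array is empty; the stream loop handles that case
/// separately from per-choice deltas.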
#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
#[serde(tag = "role")]
enum OpenAiMessage {
    #[serde(rename = "system")]
    System { content: String },
    #[serde(rename = "user")]
    User { content: Vec<OpenAiContentPart> },
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAiToolCall>>,
    },
    #[serde(rename = "tool")]
    Tool {
        tool_call_id: String,
        content: String,
    },
}

#[derive(Serialize)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamEvent {
    choices: Vec<OpenAiChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiError {
    error: OpenAiErrorDetail,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiErrorDetail {
    message: String,
}

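/// Converts Zed's provider-agnostic request into the Chat Completions wire
/// format. Tool results become `role: "tool"` messages, pushed ahead of any
/// remaining user content so they directly follow the assistant turn whose
/// tool calls they answer; image parts become base64 data URLs.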
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let mut messages = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let text: String = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => Some(t.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if !text.is_empty() {
                    messages.push(OpenAiMessage::System { content: text });
                }
            }
            LlmMessageRole::User => {
                let parts: Vec<OpenAiContentPart> = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => {
                            Some(OpenAiContentPart::Text { text: t.clone() })
                        }
                        LlmMessageContent::Image(img) => Some(OpenAiContentPart::ImageUrl {
                            image_url: ImageUrl {
                                // `img.source` carries bare base64 with no
                                // media type, so PNG is assumed here.
                                url: format!("data:image/png;base64,{}", img.source),
                            },
                        }),
                        // Tool results are not user content parts; they are
                        // emitted below as separate `role: "tool"` messages.
                        LlmMessageContent::ToolResult(_) => None,
                        _ => None,
                    })
                    .collect();

                for content in &msg.content {
                    if let LlmMessageContent::ToolResult(result) = content {
                        let content_text = match &result.content {
                            LlmToolResultContent::Text(t) => t.clone(),
                            LlmToolResultContent::Image(_) => "[Image]".to_string(),
                        };
                        messages.push(OpenAiMessage::Tool {
                            tool_call_id: result.tool_use_id.clone(),
                            content: content_text,
                        });
                    }
                }

                if !parts.is_empty() {
                    messages.push(OpenAiMessage::User { content: parts });
                }
            }
            LlmMessageRole::Assistant => {
                let mut content_text = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for c in &msg.content {
                    match c {
                        LlmMessageContent::Text(t) => {
                            // Append rather than overwrite so messages with
                            // several text parts keep all of them.
                            content_text.push_str(t);
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage::Assistant {
                    content: if content_text.is_empty() {
                        None
                    } else {
                        Some(content_text)
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                });
            }
        }
    }

    let tools: Option<Vec<OpenAiTool>> = if request.tools.is_empty() {
        None
    } else {
        Some(
            request
                .tools
                .iter()
                .map(|t| OpenAiTool {
                    tool_type: "function".to_string(),
                    function: OpenAiFunctionDef {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        parameters: serde_json::from_str(&t.input_schema)
                            .unwrap_or(serde_json::Value::Object(Default::default())),
                    },
                })
                .collect(),
        )
    };

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    // o-series and gpt-5 reasoning models accept only the default
    // temperature, so the field is omitted for them entirely.
    let temperature = if real_model_id.starts_with('o') || real_model_id.starts_with("gpt-5") {
        None
    } else {
        request.temperature
    };

    Ok(OpenAiRequest {
        model: real_model_id.to_string(),
        messages,
        tools,
        tool_choice,
        temperature,
        max_tokens: request.max_tokens,
        stop: request.stop_sequences.clone(),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

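/// Parses a single SSE line. Only `data: ` payloads are meaningful here;
/// the `[DONE]` sentinel and non-data lines (comments, keep-alives) yield
/// `None`, as do payloads that fail to deserialize.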
fn parse_sse_line(line: &str) -> Option<OpenAiStreamEvent> {
    if let Some(data) = line.strip_prefix("data: ") {
        if data == "[DONE]" {
            return None;
        }
        serde_json::from_str(data).ok()
    } else {
        None
    }
}

impl zed::Extension for OpenAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openai".into(),
            name: "OpenAI".into(),
            icon: Some("icons/openai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("openai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# OpenAI Setup

Welcome to **OpenAI**! This extension provides access to OpenAI GPT models.

## Configuration

Enter your OpenAI API key below. You can find your API key at [platform.openai.com/api-keys](https://platform.openai.com/api-keys).

## Available Models

| Display Name | Real Model | Context | Output |
|--------------|------------|---------|--------|
| GPT-4o | gpt-4o | 128K | 16K |
| GPT-4o-mini | gpt-4o-mini | 128K | 16K |
| GPT-4.1 | gpt-4.1 | 1M | 32K |
| GPT-4.1-mini | gpt-4.1-mini | 1M | 32K |
| GPT-4.1-nano | gpt-4.1-nano | 1M | 32K |
| GPT-5 | gpt-5 | 272K | 32K |
| GPT-5-mini | gpt-5-mini | 272K | 32K |
| o1 | o1 | 200K | 100K |
| o3 | o3 | 200K | 100K |
| o3-mini | o3-mini | 200K | 100K |
| o4-mini | o4-mini | 200K | 100K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling
- ✅ Vision (image inputs)
- ✅ Curated set of current OpenAI chat models

## Pricing

Uses your OpenAI API credits. See [OpenAI pricing](https://openai.com/pricing) for details.
"#
            .to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("openai")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("openai").ok_or_else(|| {
            "No API key configured. Please add your OpenAI API key in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.openai.com/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };
        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
            },
        );

        Ok(stream_id)
    }

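    /// Pulls the next completion event for `stream_id`. Each call returns at
    /// most one event: raw bytes are buffered until a full SSE line exists,
    /// and when one chunk must fan out into several events (multiple tool
    /// calls, usage followed by a stop), a synthetic finish line is requeued
    /// into the buffer so the remainder surfaces on subsequent calls.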
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].trim().to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.is_empty() {
                    continue;
                }

                if let Some(event) = parse_sse_line(&line) {
                    if let Some(choice) = event.choices.first() {
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }

                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(reason) = &choice.finish_reason {
                            if reason == "tool_calls" {
                                // Flush accumulated tool calls one per call,
                                // in index order (HashMap iteration order is
                                // arbitrary). A minimal finish-only line is
                                // requeued so the remaining calls and the
                                // final Stop event follow on later calls.
                                if let Some(&index) = state.tool_calls.keys().min() {
                                    if let Some(tool_call) = state.tool_calls.remove(&index) {
                                        state.buffer.insert_str(
                                            0,
                                            "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n",
                                        );
                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                            id: tool_call.id,
                                            name: tool_call.name,
                                            input: tool_call.arguments,
                                            thought_signature: None,
                                        })));
                                    }
                                }
                            }

                            let stop_reason = match reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if let Some(usage) = event.usage {
                                // Usage normally arrives in a trailing chunk
                                // with empty `choices`; if it rides on the
                                // finish chunk instead, emit it first and
                                // requeue a synthetic finish line so the
                                // Stop event is not lost.
                                state.buffer.insert_str(
                                    0,
                                    &format!(
                                        "data: {{\"choices\":[{{\"delta\":{{}},\"finish_reason\":\"{}\"}}]}}\n",
                                        reason
                                    ),
                                );
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: usage.prompt_tokens,
                                    output_tokens: usage.completion_tokens,
                                    cache_creation_input_tokens: None,
                                    cache_read_input_tokens: None,
                                })));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }

                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    }

                    if event.choices.is_empty() {
                        if let Some(usage) = event.usage {
                            return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_tokens,
                                output_tokens: usage.completion_tokens,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            })));
                        }
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    // Note: assumes chunk boundaries do not split multi-byte
                    // UTF-8 sequences; lossy conversion would replace a split
                    // sequence with U+FFFD rather than erroring.
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    if !state.tool_calls.is_empty() {
                        // The server closed the stream without a finish
                        // chunk; requeue a synthetic one so pending tool
                        // calls (and a Stop event) are still emitted.
                        state.buffer.insert_str(
                            0,
                            "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n",
                        );
                        continue;
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(OpenAiProvider);
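
// A minimal sanity check for the pure helpers above; a sketch assuming the
// crate compiles for a native test target. It exercises only `parse_sse_line`
// and `get_real_model_id`, which have no extension-host dependencies.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_content_delta_and_done_sentinel() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let event = parse_sse_line(line).expect("valid data line should parse");
        assert_eq!(event.choices[0].delta.content.as_deref(), Some("Hi"));

        // The terminator and non-data lines are both swallowed.
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive comment").is_none());
    }

    #[test]
    fn maps_display_names_to_wire_ids() {
        assert_eq!(get_real_model_id("GPT-4o"), Some("gpt-4o"));
        assert_eq!(get_real_model_id("no-such-model"), None);
    }
}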