openai.rs

use std::collections::HashMap;
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

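/// Zed extension that exposes OpenAI chat models as an LLM provider.
/// Active completions are tracked in `streams`, keyed by the stream id
/// handed back to the host from `llm_stream_completion_start`.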
struct OpenAiProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

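/// Per-stream parser state: the HTTP response being read, a buffer of
/// not-yet-parsed SSE text, and tool calls accumulated from deltas until
/// they can be emitted as complete events.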
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    // Tool calls still being assembled from streaming deltas, keyed by index.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    // Completed tool calls and the stop reason waiting to be emitted,
    // one event per `llm_stream_completion_next` call.
    pending_tool_calls: Vec<AccumulatedToolCall>,
    pending_stop: Option<LlmStopReason>,
}

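/// A tool call assembled from streaming deltas; OpenAI sends the id and
/// name once and the JSON arguments as incremental fragments.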
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

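/// Static description of one model offered to Zed. `display_name` also
/// serves as the model id the host passes back, so requests are mapped
/// to `real_id` before hitting the API.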
struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    is_default: bool,
    is_default_fast: bool,
}

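/// The model catalog surfaced in Zed's model picker.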
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4o-mini",
        display_name: "GPT-4o-mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-mini",
        display_name: "GPT-4.1-mini",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-nano",
        display_name: "GPT-4.1-nano",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5",
        display_name: "GPT-5",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5-mini",
        display_name: "GPT-5-mini",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3",
        display_name: "o3",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o4-mini",
        display_name: "o4-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];

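/// Maps a display name (used as the model id on the Zed side) back to
/// the wire id OpenAI expects.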
fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    MODELS
        .iter()
        .find(|m| m.display_name == display_name)
        .map(|m| m.real_id)
}

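/// Request body for `POST /v1/chat/completions`, limited to the fields
/// this extension uses.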
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // NOTE: OpenAI's reasoning models (o-series) reject `max_tokens` in
    // favor of `max_completion_tokens`; `max_tokens` is kept here for
    // compatibility with the broader chat-completions API.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    stream: bool,
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

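/// A chat message in OpenAI's wire format; serde tags each variant with
/// its `role`.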
#[derive(Serialize)]
#[serde(tag = "role")]
enum OpenAiMessage {
    #[serde(rename = "system")]
    System { content: String },
    #[serde(rename = "user")]
    User { content: Vec<OpenAiContentPart> },
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAiToolCall>>,
    },
    #[serde(rename = "tool")]
    Tool {
        tool_call_id: String,
        content: String,
    },
}

#[derive(Serialize)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Deserialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

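/// A single chunk from the streaming endpoint. With
/// `stream_options.include_usage` set, the final chunk has an empty
/// `choices` array and carries `usage`.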
#[derive(Deserialize, Debug)]
struct OpenAiStreamEvent {
    choices: Vec<OpenAiChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiError {
    error: OpenAiErrorDetail,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct OpenAiErrorDetail {
    message: String,
}

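/// Translates Zed's provider-agnostic request into OpenAI's
/// chat-completions format. Tool results are hoisted out of user
/// messages into separate `role: "tool"` messages, as the API requires.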
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;

    let mut messages = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let text: String = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => Some(t.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if !text.is_empty() {
                    messages.push(OpenAiMessage::System { content: text });
                }
            }
            LlmMessageRole::User => {
                let parts: Vec<OpenAiContentPart> = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => {
                            Some(OpenAiContentPart::Text { text: t.clone() })
                        }
                        LlmMessageContent::Image(img) => Some(OpenAiContentPart::ImageUrl {
                            image_url: ImageUrl {
                                url: format!("data:image/png;base64,{}", img.source),
                            },
                        }),
                        // Tool results become separate `role: "tool"` messages below.
                        _ => None,
                    })
                    .collect();

                for content in &msg.content {
                    if let LlmMessageContent::ToolResult(result) = content {
                        let content_text = match &result.content {
                            LlmToolResultContent::Text(t) => t.clone(),
                            LlmToolResultContent::Image(_) => "[Image]".to_string(),
                        };
                        messages.push(OpenAiMessage::Tool {
                            tool_call_id: result.tool_use_id.clone(),
                            content: content_text,
                        });
                    }
                }

                if !parts.is_empty() {
                    messages.push(OpenAiMessage::User { content: parts });
                }
            }
            LlmMessageRole::Assistant => {
                let mut content_text: Option<String> = None;
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for c in &msg.content {
                    match c {
                        LlmMessageContent::Text(t) => {
                            // Concatenate rather than overwrite, in case the
                            // message carries multiple text parts.
                            content_text.get_or_insert_with(String::new).push_str(t);
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage::Assistant {
                    content: content_text,
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                });
            }
        }
    }

    let tools: Option<Vec<OpenAiTool>> = if request.tools.is_empty() {
        None
    } else {
        Some(
            request
                .tools
                .iter()
                .map(|t| OpenAiTool {
                    tool_type: "function".to_string(),
                    function: OpenAiFunctionDef {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        // Fall back to an empty object if the schema is not
                        // valid JSON rather than failing the whole request.
                        parameters: serde_json::from_str(&t.input_schema)
                            .unwrap_or(serde_json::Value::Object(Default::default())),
                    },
                })
                .collect(),
        )
    };

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    Ok(OpenAiRequest {
        model: real_model_id.to_string(),
        messages,
        tools,
        tool_choice,
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        stop: request.stop_sequences.clone(),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

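/// Parses a single SSE line. Returns `None` for non-data lines, for the
/// `data: [DONE]` terminator, and for payloads that fail to deserialize.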
fn parse_sse_line(line: &str) -> Option<OpenAiStreamEvent> {
    if let Some(data) = line.strip_prefix("data: ") {
        if data == "[DONE]" {
            return None;
        }
        serde_json::from_str(data).ok()
    } else {
        None
    }
}

impl zed::Extension for OpenAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openai".into(),
            name: "OpenAI".into(),
            icon: Some("icons/openai.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("openai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "[Create an API key](https://platform.openai.com/api-keys) to use OpenAI as your LLM provider.".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("openai")
    }

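    /// Starts a streaming completion: builds the request body, opens the
    /// SSE response, and registers the stream under a fresh id.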
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("openai").ok_or_else(|| {
            "No API key configured. Please add your OpenAI API key in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.openai.com/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                pending_tool_calls: Vec::new(),
                pending_stop: None,
            },
        );

        Ok(stream_id)
    }

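    /// Returns the next event for a stream. Events queued from an earlier
    /// chunk (remaining tool calls, a stop reason) are drained first;
    /// otherwise buffered SSE lines are parsed, reading more HTTP chunks
    /// as needed.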
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain events queued by a previously parsed chunk: remaining
        // tool calls first, then the stop reason.
        if !state.pending_tool_calls.is_empty() {
            let tool_call = state.pending_tool_calls.remove(0);
            return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                id: tool_call.id,
                name: tool_call.name,
                input: tool_call.arguments,
                is_input_complete: true,
                thought_signature: None,
            })));
        }
        if let Some(stop_reason) = state.pending_stop.take() {
            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].trim().to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.is_empty() {
                    continue;
                }

                if let Some(event) = parse_sse_line(&line) {
                    if let Some(choice) = event.choices.first() {
                        // Tool calls arrive as fragments keyed by index;
                        // accumulate them until a finish reason arrives.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }

                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(reason) = &choice.finish_reason {
                            let stop_reason = match reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            // Queue every accumulated tool call (not just the
                            // first) in index order; the drain at the top of
                            // this function emits one event per call.
                            let mut calls: Vec<(usize, AccumulatedToolCall)> =
                                state.tool_calls.drain().collect();
                            calls.sort_by_key(|&(index, _)| index);
                            state
                                .pending_tool_calls
                                .extend(calls.into_iter().map(|(_, call)| call));

                            // Usage normally arrives in a separate final chunk,
                            // but emit it first if it is attached to this one.
                            if let Some(usage) = event.usage {
                                state.pending_stop = Some(stop_reason);
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: usage.prompt_tokens,
                                    output_tokens: usage.completion_tokens,
                                    cache_creation_input_tokens: None,
                                    cache_read_input_tokens: None,
                                })));
                            }

                            if !state.pending_tool_calls.is_empty() {
                                state.pending_stop = Some(stop_reason);
                                let tool_call = state.pending_tool_calls.remove(0);
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tool_call.id,
                                    name: tool_call.name,
                                    input: tool_call.arguments,
                                    is_input_complete: true,
                                    thought_signature: None,
                                })));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }

                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    }

                    // With `include_usage`, the final chunk has no choices
                    // and carries the token counts.
                    if event.choices.is_empty() {
                        if let Some(usage) = event.usage {
                            return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_tokens,
                                output_tokens: usage.completion_tokens,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            })));
                        }
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    // NOTE: a multi-byte character straddling a chunk boundary
                    // would be mangled by from_utf8_lossy; SSE payloads are
                    // small line-delimited JSON, so this is unlikely in practice.
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // The stream ended without a finish reason: flush any
                    // accumulated tool calls rather than dropping them, and
                    // queue a stop so the event sequence stays well formed.
                    if !state.tool_calls.is_empty() {
                        let mut calls: Vec<(usize, AccumulatedToolCall)> =
                            state.tool_calls.drain().collect();
                        calls.sort_by_key(|&(index, _)| index);
                        state
                            .pending_tool_calls
                            .extend(calls.into_iter().map(|(_, call)| call));
                        state.pending_stop = Some(LlmStopReason::ToolUse);
                        let tool_call = state.pending_tool_calls.remove(0);
                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                            id: tool_call.id,
                            name: tool_call.name,
                            input: tool_call.arguments,
                            is_input_complete: true,
                            thought_signature: None,
                        })));
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(OpenAiProvider);