//! language_model.rs — core traits and types for language models and their providers.

  1mod api_key;
  2mod model;
  3mod registry;
  4mod request;
  5
  6#[cfg(any(test, feature = "test-support"))]
  7pub mod fake_provider;
  8
  9pub use language_model_core::*;
 10
 11use anyhow::Result;
 12use futures::FutureExt;
 13use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 14use gpui::{AnyView, App, AsyncApp, Task, Window};
 15use icons::IconName;
 16use parking_lot::Mutex;
 17use std::sync::Arc;
 18
 19pub use crate::api_key::{ApiKey, ApiKeyState};
 20pub use crate::model::*;
 21pub use crate::registry::*;
 22pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
 23pub use env_var::{EnvVar, env_var};
 24
/// Initializes this crate's global state by registering the language model
/// registry in the application context. Call once at startup.
pub fn init(cx: &mut App) {
    registry::init(cx);
}
 28
/// A completion reduced to a stream of plain text chunks, with the message id
/// and token usage reported alongside.
pub struct LanguageModelTextStream {
    // Provider-assigned id of the streamed message, when the provider reported one.
    pub message_id: Option<String>,
    // Text chunks (or errors) in arrival order.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
 35
 36impl Default for LanguageModelTextStream {
 37    fn default() -> Self {
 38        Self {
 39            message_id: None,
 40            stream: Box::pin(futures::stream::empty()),
 41            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
 42        }
 43    }
 44}
 45
 46pub trait LanguageModel: Send + Sync {
 47    fn id(&self) -> LanguageModelId;
 48    fn name(&self) -> LanguageModelName;
 49    fn provider_id(&self) -> LanguageModelProviderId;
 50    fn provider_name(&self) -> LanguageModelProviderName;
 51    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 52        self.provider_id()
 53    }
 54    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 55        self.provider_name()
 56    }
 57
 58    /// Returns whether this model is the "latest", so we can highlight it in the UI.
 59    fn is_latest(&self) -> bool {
 60        false
 61    }
 62
 63    fn telemetry_id(&self) -> String;
 64
 65    fn api_key(&self, _cx: &App) -> Option<String> {
 66        None
 67    }
 68
 69    /// Information about the cost of using this model, if available.
 70    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
 71        None
 72    }
 73
 74    /// Whether this model supports thinking.
 75    fn supports_thinking(&self) -> bool {
 76        false
 77    }
 78
 79    fn supports_fast_mode(&self) -> bool {
 80        false
 81    }
 82
 83    /// Returns the list of supported effort levels that can be used when thinking.
 84    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 85        Vec::new()
 86    }
 87
 88    /// Returns the default effort level to use when thinking.
 89    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
 90        self.supported_effort_levels()
 91            .into_iter()
 92            .find(|effort_level| effort_level.is_default)
 93    }
 94
 95    /// Whether this model supports images
 96    fn supports_images(&self) -> bool;
 97
 98    /// Whether this model supports tools.
 99    fn supports_tools(&self) -> bool;
100
101    /// Whether this model supports choosing which tool to use.
102    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
103
104    /// Returns whether this model or provider supports streaming tool calls;
105    fn supports_streaming_tools(&self) -> bool {
106        false
107    }
108
109    /// Returns whether this model/provider reports accurate split input/output token counts.
110    /// When true, the UI may show separate input/output token indicators.
111    fn supports_split_token_display(&self) -> bool {
112        false
113    }
114
115    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
116        LanguageModelToolSchemaFormat::JsonSchema
117    }
118
119    fn max_token_count(&self) -> u64;
120    fn max_output_tokens(&self) -> Option<u64> {
121        None
122    }
123
124    fn count_tokens(
125        &self,
126        request: LanguageModelRequest,
127        cx: &App,
128    ) -> BoxFuture<'static, Result<u64>>;
129
130    fn stream_completion(
131        &self,
132        request: LanguageModelRequest,
133        cx: &AsyncApp,
134    ) -> BoxFuture<
135        'static,
136        Result<
137            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
138            LanguageModelCompletionError,
139        >,
140    >;
141
142    fn stream_completion_text(
143        &self,
144        request: LanguageModelRequest,
145        cx: &AsyncApp,
146    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
147        let future = self.stream_completion(request, cx);
148
149        async move {
150            let events = future.await?;
151            let mut events = events.fuse();
152            let mut message_id = None;
153            let mut first_item_text = None;
154            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
155
156            if let Some(first_event) = events.next().await {
157                match first_event {
158                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
159                        message_id = Some(id);
160                    }
161                    Ok(LanguageModelCompletionEvent::Text(text)) => {
162                        first_item_text = Some(text);
163                    }
164                    _ => (),
165                }
166            }
167
168            let stream = futures::stream::iter(first_item_text.map(Ok))
169                .chain(events.filter_map({
170                    let last_token_usage = last_token_usage.clone();
171                    move |result| {
172                        let last_token_usage = last_token_usage.clone();
173                        async move {
174                            match result {
175                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
176                                Ok(LanguageModelCompletionEvent::Started) => None,
177                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
178                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
179                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
180                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
181                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
182                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
183                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
184                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
185                                    ..
186                                }) => None,
187                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
188                                    *last_token_usage.lock() = token_usage;
189                                    None
190                                }
191                                Err(err) => Some(Err(err)),
192                            }
193                        }
194                    }
195                }))
196                .boxed();
197
198            Ok(LanguageModelTextStream {
199                message_id,
200                stream,
201                last_token_usage,
202            })
203        }
204        .boxed()
205    }
206
207    fn stream_completion_tool(
208        &self,
209        request: LanguageModelRequest,
210        cx: &AsyncApp,
211    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
212        let future = self.stream_completion(request, cx);
213
214        async move {
215            let events = future.await?;
216            let mut events = events.fuse();
217
218            // Iterate through events until we find a complete ToolUse
219            while let Some(event) = events.next().await {
220                match event {
221                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
222                        if tool_use.is_input_complete =>
223                    {
224                        return Ok(tool_use);
225                    }
226                    Err(err) => {
227                        return Err(err);
228                    }
229                    _ => {}
230                }
231            }
232
233            // Stream ended without a complete tool use
234            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
235                "Stream ended without receiving a complete tool use"
236            )))
237        }
238        .boxed()
239    }
240
241    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
242        None
243    }
244
245    #[cfg(any(test, feature = "test-support"))]
246    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
247        unimplemented!()
248    }
249}
250
251impl std::fmt::Debug for dyn LanguageModel {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("<dyn LanguageModel>")
254            .field("id", &self.id())
255            .field("name", &self.name())
256            .field("provider_id", &self.provider_id())
257            .field("provider_name", &self.provider_name())
258            .field("upstream_provider_name", &self.upstream_provider_name())
259            .field("upstream_provider_id", &self.upstream_provider_id())
260            .field("upstream_provider_id", &self.upstream_provider_id())
261            .field("supports_streaming_tools", &self.supports_streaming_tools())
262            .finish()
263    }
264}
265
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
274
impl Default for IconOrSvg {
    /// Falls back to the built-in Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
280
/// A source of language models (e.g. a vendor integration) that can be
/// authenticated, queried for its models, and configured through a view.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI; defaults to the Zed assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models to surface prominently in pickers; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the settings/configuration UI for this provider.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
303
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other agent, identified by name.
    Other(SharedString),
}
310
311pub trait LanguageModelProviderState: 'static {
312    type ObservableEntity;
313
314    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
315
316    fn subscribe<T: 'static>(
317        &self,
318        cx: &mut gpui::Context<T>,
319        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
320    ) -> Option<gpui::Subscription> {
321        let entity = self.observable_entity()?;
322        Some(cx.observe(&entity, move |this, _, cx| {
323            callback(this, cx);
324        }))
325    }
326}
327
/// How usage of a model is priced.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (the `_per_1m` fields).
    // NOTE(review): the comment previously said "per 1,000", contradicting the
    // field names — corrected to match `*_per_1m`.
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
338
339impl LanguageModelCostInfo {
340    pub fn to_shared_string(&self) -> SharedString {
341        match self {
342            LanguageModelCostInfo::RequestCost { cost_per_request } => {
343                let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
344                SharedString::from(cost_str)
345            }
346            LanguageModelCostInfo::TokenCost {
347                input_token_cost_per_1m,
348                output_token_cost_per_1m,
349            } => {
350                let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
351                let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
352                SharedString::from(format!("{}$/{}$", input_cost, output_cost))
353            }
354        }
355    }
356
357    fn cost_value_to_string(cost: &f64) -> SharedString {
358        if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
359            SharedString::from(format!("{:.0}", cost))
360        } else {
361            SharedString::from(format!("{:.2}", cost))
362        }
363    }
364}