language_model.rs

  1mod api_key;
  2mod model;
  3mod registry;
  4mod request;
  5
  6#[cfg(any(test, feature = "test-support"))]
  7pub mod fake_provider;
  8
  9pub use language_model_core::*;
 10
 11use anyhow::Result;
 12use futures::FutureExt;
 13use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 14use gpui::{AnyView, App, AsyncApp, Task, Window};
 15use icons::IconName;
 16use parking_lot::Mutex;
 17use std::sync::Arc;
 18
 19pub use crate::api_key::{ApiKey, ApiKeyState};
 20pub use crate::model::*;
 21pub use crate::registry::*;
 22pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
 23pub use env_var::{EnvVar, env_var};
 24
 25pub fn init(cx: &mut App) {
 26    registry::init(cx);
 27}
 28
/// Text-only view of a model completion, as returned by
/// [`LanguageModel::stream_completion_text`].
pub struct LanguageModelTextStream {
    /// Id of the message, if one was provided. The default `stream_completion_text`
    /// takes it from a `StartMessage` event that arrives first.
    pub message_id: Option<String>,
    /// Stream of text chunks (or errors) making up the completion.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Most recent token usage reported while `stream` is polled; it is shared with the
    /// stream and holds the complete token usage after the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
 35
 36impl Default for LanguageModelTextStream {
 37    fn default() -> Self {
 38        Self {
 39            message_id: None,
 40            stream: Box::pin(futures::stream::empty()),
 41            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
 42        }
 43    }
 44}
 45
 46pub trait LanguageModel: Send + Sync {
 47    fn id(&self) -> LanguageModelId;
 48    fn name(&self) -> LanguageModelName;
 49    fn provider_id(&self) -> LanguageModelProviderId;
 50    fn provider_name(&self) -> LanguageModelProviderName;
 51    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 52        self.provider_id()
 53    }
 54    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 55        self.provider_name()
 56    }
 57
 58    /// Returns whether this model is the "latest", so we can highlight it in the UI.
 59    fn is_latest(&self) -> bool {
 60        false
 61    }
 62
 63    fn telemetry_id(&self) -> String;
 64
 65    fn api_key(&self, _cx: &App) -> Option<String> {
 66        None
 67    }
 68
 69    /// Information about the cost of using this model, if available.
 70    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
 71        None
 72    }
 73
 74    /// Whether this model supports thinking.
 75    fn supports_thinking(&self) -> bool {
 76        false
 77    }
 78
 79    fn supports_fast_mode(&self) -> bool {
 80        false
 81    }
 82
 83    /// Returns the list of supported effort levels that can be used when thinking.
 84    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 85        Vec::new()
 86    }
 87
 88    /// Returns the default effort level to use when thinking.
 89    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
 90        self.supported_effort_levels()
 91            .into_iter()
 92            .find(|effort_level| effort_level.is_default)
 93    }
 94
 95    /// Whether this model supports images
 96    fn supports_images(&self) -> bool;
 97
 98    /// Whether this model supports tools.
 99    fn supports_tools(&self) -> bool;
100
101    /// Whether this model supports choosing which tool to use.
102    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
103
104    /// Returns whether this model or provider supports streaming tool calls;
105    fn supports_streaming_tools(&self) -> bool {
106        false
107    }
108
109    /// Returns whether this model/provider reports accurate split input/output token counts.
110    /// When true, the UI may show separate input/output token indicators.
111    fn supports_split_token_display(&self) -> bool {
112        false
113    }
114
115    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
116        LanguageModelToolSchemaFormat::JsonSchema
117    }
118
119    fn max_token_count(&self) -> u64;
120    fn max_output_tokens(&self) -> Option<u64> {
121        None
122    }
123
124    fn stream_completion(
125        &self,
126        request: LanguageModelRequest,
127        cx: &AsyncApp,
128    ) -> BoxFuture<
129        'static,
130        Result<
131            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
132            LanguageModelCompletionError,
133        >,
134    >;
135
136    fn stream_completion_text(
137        &self,
138        request: LanguageModelRequest,
139        cx: &AsyncApp,
140    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
141        let future = self.stream_completion(request, cx);
142
143        async move {
144            let events = future.await?;
145            let mut events = events.fuse();
146            let mut message_id = None;
147            let mut first_item_text = None;
148            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
149
150            if let Some(first_event) = events.next().await {
151                match first_event {
152                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
153                        message_id = Some(id);
154                    }
155                    Ok(LanguageModelCompletionEvent::Text(text)) => {
156                        first_item_text = Some(text);
157                    }
158                    _ => (),
159                }
160            }
161
162            let stream = futures::stream::iter(first_item_text.map(Ok))
163                .chain(events.filter_map({
164                    let last_token_usage = last_token_usage.clone();
165                    move |result| {
166                        let last_token_usage = last_token_usage.clone();
167                        async move {
168                            match result {
169                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
170                                Ok(LanguageModelCompletionEvent::Started) => None,
171                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
172                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
173                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
174                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
175                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
176                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
177                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
178                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
179                                    ..
180                                }) => None,
181                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
182                                    *last_token_usage.lock() = token_usage;
183                                    None
184                                }
185                                Err(err) => Some(Err(err)),
186                            }
187                        }
188                    }
189                }))
190                .boxed();
191
192            Ok(LanguageModelTextStream {
193                message_id,
194                stream,
195                last_token_usage,
196            })
197        }
198        .boxed()
199    }
200
201    fn stream_completion_tool(
202        &self,
203        request: LanguageModelRequest,
204        cx: &AsyncApp,
205    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
206        let future = self.stream_completion(request, cx);
207
208        async move {
209            let events = future.await?;
210            let mut events = events.fuse();
211
212            // Iterate through events until we find a complete ToolUse
213            while let Some(event) = events.next().await {
214                match event {
215                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
216                        if tool_use.is_input_complete =>
217                    {
218                        return Ok(tool_use);
219                    }
220                    Err(err) => {
221                        return Err(err);
222                    }
223                    _ => {}
224                }
225            }
226
227            // Stream ended without a complete tool use
228            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
229                "Stream ended without receiving a complete tool use"
230            )))
231        }
232        .boxed()
233    }
234
235    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
236        None
237    }
238
239    #[cfg(any(test, feature = "test-support"))]
240    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
241        unimplemented!()
242    }
243}
244
245impl std::fmt::Debug for dyn LanguageModel {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        f.debug_struct("<dyn LanguageModel>")
248            .field("id", &self.id())
249            .field("name", &self.name())
250            .field("provider_id", &self.provider_id())
251            .field("provider_name", &self.provider_name())
252            .field("upstream_provider_name", &self.upstream_provider_name())
253            .field("upstream_provider_id", &self.upstream_provider_id())
254            .field("upstream_provider_id", &self.upstream_provider_id())
255            .field("supports_streaming_tools", &self.supports_streaming_tools())
256            .finish()
257    }
258}
259
/// Either a built-in icon name or a path to an external SVG.
///
/// Returned by [`LanguageModelProvider::icon`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
268
269impl Default for IconOrSvg {
270    fn default() -> Self {
271        Self::Icon(IconName::ZedAssistant)
272    }
273}
274
275pub trait LanguageModelProvider: 'static {
276    fn id(&self) -> LanguageModelProviderId;
277    fn name(&self) -> LanguageModelProviderName;
278    fn icon(&self) -> IconOrSvg {
279        IconOrSvg::default()
280    }
281    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
282    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
283    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
284    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
285        Vec::new()
286    }
287    fn is_authenticated(&self, cx: &App) -> bool;
288    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
289    fn configuration_view(
290        &self,
291        target_agent: ConfigurationViewTargetAgent,
292        window: &mut Window,
293        cx: &mut App,
294    ) -> AnyView;
295    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
296}
297
/// The agent a provider's configuration view is built for; passed to
/// [`LanguageModelProvider::configuration_view`].
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's own agent. This is the default.
    #[default]
    ZedAgent,
    /// Some other agent, identified by name (NOTE(review): presumably a display name — confirm).
    Other(SharedString),
}
304
305pub trait LanguageModelProviderState: 'static {
306    type ObservableEntity;
307
308    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
309
310    fn subscribe<T: 'static>(
311        &self,
312        cx: &mut gpui::Context<T>,
313        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
314    ) -> Option<gpui::Subscription> {
315        let entity = self.observable_entity()?;
316        Some(cx.observe(&entity, move |this, _, cx| {
317            callback(this, cx);
318        }))
319    }
320}
321
/// Pricing information for a language model, as returned by
/// [`LanguageModel::model_cost_info`].
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1,000,000 (1M) input and output tokens.
    TokenCost {
        /// Cost of 1M input tokens.
        input_token_cost_per_1m: f64,
        /// Cost of 1M output tokens.
        output_token_cost_per_1m: f64,
    },
    /// Cost per request.
    RequestCost {
        /// Cost of a single request.
        cost_per_request: f64,
    },
}
332
333impl LanguageModelCostInfo {
334    pub fn to_shared_string(&self) -> SharedString {
335        match self {
336            LanguageModelCostInfo::RequestCost { cost_per_request } => {
337                let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
338                SharedString::from(cost_str)
339            }
340            LanguageModelCostInfo::TokenCost {
341                input_token_cost_per_1m,
342                output_token_cost_per_1m,
343            } => {
344                let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
345                let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
346                SharedString::from(format!("{}$/{}$", input_cost, output_cost))
347            }
348        }
349    }
350
351    fn cost_value_to_string(cost: &f64) -> SharedString {
352        if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
353            SharedString::from(format!("{:.0}", cost))
354        } else {
355            SharedString::from(format!("{:.2}", cost))
356        }
357    }
358}