1mod api_key;
2mod model;
3mod registry;
4mod request;
5
6#[cfg(any(test, feature = "test-support"))]
7pub mod fake_provider;
8
9pub use language_model_core::*;
10
11use anyhow::Result;
12use futures::FutureExt;
13use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
14use gpui::{AnyView, App, AsyncApp, Task, Window};
15use icons::IconName;
16use parking_lot::Mutex;
17use std::sync::Arc;
18
19pub use crate::api_key::{ApiKey, ApiKeyState};
20pub use crate::model::*;
21pub use crate::registry::*;
22pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
23pub use env_var::{EnvVar, env_var};
24
/// Initializes this crate's global state by registering the language model
/// registry in the application context. Call once at startup.
pub fn init(cx: &mut App) {
    registry::init(cx);
}
28
/// A completion reduced to a stream of plain text chunks, plus metadata
/// gathered while the stream is consumed.
pub struct LanguageModelTextStream {
    /// Provider-assigned id of the message, when the stream announced one.
    pub message_id: Option<String>,
    /// The text chunks of the completion, in order; errors surface inline.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
35
36impl Default for LanguageModelTextStream {
37 fn default() -> Self {
38 Self {
39 message_id: None,
40 stream: Box::pin(futures::stream::empty()),
41 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
42 }
43 }
44}
45
/// A single language model exposed by a [`LanguageModelProvider`].
///
/// Implementors describe the model's identity and capabilities and implement
/// the raw [`stream_completion`](Self::stream_completion) entry point; the
/// provided `stream_completion_*` methods adapt that event stream into
/// text-only or tool-use-only forms.
pub trait LanguageModel: Send + Sync {
    /// Unique id of this model within its provider.
    fn id(&self) -> LanguageModelId;
    /// Human-readable name of this model.
    fn name(&self) -> LanguageModelName;
    /// Id of the provider serving this model.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Human-readable name of the provider serving this model.
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Id of the upstream provider; defaults to the immediate provider's id.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Name of the upstream provider; defaults to the immediate provider's name.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    /// Stable identifier used when reporting telemetry about this model.
    fn telemetry_id(&self) -> String;

    /// The API key used to talk to this model, if one is available.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Whether this model supports a "fast" mode.
    fn supports_fast_mode(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    ///
    /// Picks the first supported level flagged `is_default`, if any.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls.
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// The schema format this model expects for tool input definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Maximum number of input tokens the model accepts.
    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, if the model imposes a limit.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens that `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Starts a completion and returns the raw stream of completion events.
    ///
    /// This is the primary entry point; the other `stream_completion_*`
    /// methods are adapters built on top of it.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Adapts [`Self::stream_completion`] into a text-only stream, capturing
    /// the message id (when the first event announces one) and recording the
    /// most recent token usage as the stream is consumed.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Inspect the first event separately: it may carry the message id,
            // or it may already be text that must be re-emitted at the head of
            // the resulting stream.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-emit any text consumed above, then keep only text chunks and
            // errors from the remaining events; all other event kinds are
            // filtered out.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    // Overwrite each time so the shared slot
                                    // holds the latest usage once the stream
                                    // has been drained.
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Adapts [`Self::stream_completion`] into a single tool use: resolves
    /// with the first tool use whose input is complete, or errors if the
    /// stream ends (or fails) without producing one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching configuration for this model, if it supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Downcasts to the fake test model; panics for real implementations.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
250
251impl std::fmt::Debug for dyn LanguageModel {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 f.debug_struct("<dyn LanguageModel>")
254 .field("id", &self.id())
255 .field("name", &self.name())
256 .field("provider_id", &self.provider_id())
257 .field("provider_name", &self.provider_name())
258 .field("upstream_provider_name", &self.upstream_provider_name())
259 .field("upstream_provider_id", &self.upstream_provider_id())
260 .field("upstream_provider_id", &self.upstream_provider_id())
261 .field("supports_streaming_tools", &self.supports_streaming_tools())
262 .finish()
263 }
264}
265
/// Either a built-in icon name or a path to an external SVG.
///
/// Used by [`LanguageModelProvider::icon`] to describe how a provider is
/// rendered in the UI.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
274
275impl Default for IconOrSvg {
276 fn default() -> Self {
277 Self::Icon(IconName::ZedAssistant)
278 }
279}
280
/// A source of language models that can enumerate its models and manage
/// authentication and configuration.
pub trait LanguageModelProvider: 'static {
    /// Unique id of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI; defaults to [`IconOrSvg::default`].
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if one is available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A curated subset of models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider currently has valid credentials.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts (or re-runs) the provider's authentication flow.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider for `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
303
/// Identifies which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other agent, identified by name.
    Other(SharedString),
}
310
311pub trait LanguageModelProviderState: 'static {
312 type ObservableEntity;
313
314 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
315
316 fn subscribe<T: 'static>(
317 &self,
318 cx: &mut gpui::Context<T>,
319 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
320 ) -> Option<gpui::Subscription> {
321 let entity = self.observable_entity()?;
322 Some(cx.observe(&entity, move |this, _, cx| {
323 callback(this, cx);
324 }))
325 }
326}
327
/// How usage of a model is priced.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (the fields are `_per_1m`,
    /// i.e. per million, not per thousand).
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
338
339impl LanguageModelCostInfo {
340 pub fn to_shared_string(&self) -> SharedString {
341 match self {
342 LanguageModelCostInfo::RequestCost { cost_per_request } => {
343 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
344 SharedString::from(cost_str)
345 }
346 LanguageModelCostInfo::TokenCost {
347 input_token_cost_per_1m,
348 output_token_cost_per_1m,
349 } => {
350 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
351 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
352 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
353 }
354 }
355 }
356
357 fn cost_value_to_string(cost: &f64) -> SharedString {
358 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
359 SharedString::from(format!("{:.0}", cost))
360 } else {
361 SharedString::from(format!("{:.2}", cost))
362 }
363 }
364}