1mod api_key;
2mod model;
3mod registry;
4mod request;
5
6#[cfg(any(test, feature = "test-support"))]
7pub mod fake_provider;
8
9pub use language_model_core::*;
10
11use anyhow::Result;
12use futures::FutureExt;
13use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
14use gpui::{AnyView, App, AsyncApp, Task, Window};
15use icons::IconName;
16use parking_lot::Mutex;
17use std::sync::Arc;
18
19pub use crate::api_key::{ApiKey, ApiKeyState};
20pub use crate::model::*;
21pub use crate::registry::*;
22pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
23pub use env_var::{EnvVar, env_var};
24
/// Initializes this crate's global state by delegating to the registry's
/// initializer. Call once during application startup.
pub fn init(cx: &mut App) {
    registry::init(cx);
}
28
/// A text-only view of a streaming completion: an optional message id, a
/// stream of text chunks, and a shared slot that receives the token usage.
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, captured from the stream's first
    /// `StartMessage` event when the provider emits one.
    pub message_id: Option<String>,
    /// Stream of text chunks (or completion errors).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
35
36impl Default for LanguageModelTextStream {
37 fn default() -> Self {
38 Self {
39 message_id: None,
40 stream: Box::pin(futures::stream::empty()),
41 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
42 }
43 }
44}
45
/// A language model exposed by a [`LanguageModelProvider`].
///
/// Capability queries default to the most conservative answer (`false`,
/// `None`, or empty) so implementations only override what they support.
pub trait LanguageModel: Send + Sync {
    /// Stable identifier for this model within its provider.
    fn id(&self) -> LanguageModelId;
    /// Human-readable model name for display.
    fn name(&self) -> LanguageModelName;
    /// Identifier of the provider serving this model.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Human-readable name of the provider serving this model.
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Identifier of the upstream provider when this model is proxied
    /// through another service; defaults to the direct provider.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Name counterpart of [`LanguageModel::upstream_provider_id`].
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    /// Opaque identifier used when reporting telemetry about this model.
    fn telemetry_id(&self) -> String;

    /// The API key configured for this model, if any.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Whether this model supports a fast mode.
    fn supports_fast_mode(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    ///
    /// Picks the first supported level flagged `is_default`, if any.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls.
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// The schema format this model expects for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Maximum size of the model's context window, in tokens.
    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, when the model imposes a limit.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Starts a completion request and returns a stream of completion events.
    ///
    /// The outer future resolves once the request is established; the inner
    /// stream yields events (text, tool use, usage updates, ...) or errors.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Convenience wrapper over [`LanguageModel::stream_completion`] that
    /// reduces the event stream to plain text chunks, capturing the message
    /// id from the first event and recording the latest token usage.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: capture a `StartMessage` id, or hold
            // on to the first text chunk so it isn't lost when the rest of
            // the stream is adapted below.
            // NOTE(review): any other first event — including an `Err` — is
            // silently dropped here; confirm that is intended.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend the peeked text chunk (if any), then keep only text
            // events from the remainder: usage updates are recorded into
            // `last_token_usage`, errors are forwarded, everything else is
            // filtered out.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Convenience wrapper over [`LanguageModel::stream_completion`] that
    /// resolves to the first *complete* tool use in the stream, or an error
    /// if the stream fails or ends without one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching configuration for this model, when supported.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Test-only downcast to the fake model; panics for real implementations.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
244
245impl std::fmt::Debug for dyn LanguageModel {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 f.debug_struct("<dyn LanguageModel>")
248 .field("id", &self.id())
249 .field("name", &self.name())
250 .field("provider_id", &self.provider_id())
251 .field("provider_name", &self.provider_name())
252 .field("upstream_provider_name", &self.upstream_provider_name())
253 .field("upstream_provider_id", &self.upstream_provider_id())
254 .field("upstream_provider_id", &self.upstream_provider_id())
255 .field("supports_streaming_tools", &self.supports_streaming_tools())
256 .finish()
257 }
258}
259
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
268
impl Default for IconOrSvg {
    /// Defaults to the built-in Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
274
/// A source of language models (e.g. a hosted API or local runtime) that can
/// authenticate, enumerate its models, and render its configuration UI.
pub trait LanguageModelProvider: 'static {
    /// Stable identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable provider name for display.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The model to select by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The default model for speed-sensitive tasks, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models surfaced as recommendations; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether this provider currently has working credentials.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Performs (or re-verifies) authentication for this provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider for `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
297
/// Which agent a provider's configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Another agent, identified by name.
    Other(SharedString),
}
304
/// Observable state for a language model provider, letting UI entities
/// re-render when the provider's backing entity changes.
pub trait LanguageModelProviderState: 'static {
    /// The gpui entity type whose changes should be observed.
    type ObservableEntity;

    /// The entity to observe, if the provider exposes one.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Subscribes `callback` to changes of the observable entity; returns
    /// `None` when [`Self::observable_entity`] yields nothing to observe.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
321
/// How usage of a model is priced.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (field names are `*_per_1m`).
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
332
333impl LanguageModelCostInfo {
334 pub fn to_shared_string(&self) -> SharedString {
335 match self {
336 LanguageModelCostInfo::RequestCost { cost_per_request } => {
337 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
338 SharedString::from(cost_str)
339 }
340 LanguageModelCostInfo::TokenCost {
341 input_token_cost_per_1m,
342 output_token_cost_per_1m,
343 } => {
344 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
345 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
346 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
347 }
348 }
349 }
350
351 fn cost_value_to_string(cost: &f64) -> SharedString {
352 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
353 SharedString::from(format!("{:.0}", cost))
354 } else {
355 SharedString::from(format!("{:.2}", cost))
356 }
357 }
358}