1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7
8#[cfg(any(test, feature = "test-support"))]
9pub mod fake_provider;
10
11use anyhow::Result;
12use client::Client;
13use futures::FutureExt;
14use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
15use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
16use icons::IconName;
17use parking_lot::Mutex;
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize, de::DeserializeOwned};
20use std::fmt;
21use std::ops::{Add, Sub};
22use std::sync::Arc;
23use std::time::Duration;
24use thiserror::Error;
25use util::serde::is_default;
26use zed_llm_client::CompletionRequestStatus;
27
28pub use crate::model::*;
29pub use crate::rate_limiter::*;
30pub use crate::registry::*;
31pub use crate::request::*;
32pub use crate::role::*;
33pub use crate::telemetry::*;
34
35pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
36
37pub fn init(client: Arc<Client>, cx: &mut App) {
38 init_settings(cx);
39 RefreshLlmTokenListener::register(client.clone(), cx);
40}
41
/// Registers this crate's settings (via the model registry) without requiring
/// a `Client`; called by [`init`].
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
45
46/// Configuration for caching language model messages.
47#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
48pub struct LanguageModelCacheConfiguration {
49 pub max_cache_anchors: usize,
50 pub should_speculate: bool,
51 pub min_total_token: u64,
52}
53
54/// A completion event from a language model.
55#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
56pub enum LanguageModelCompletionEvent {
57 StatusUpdate(CompletionRequestStatus),
58 Stop(StopReason),
59 Text(String),
60 Thinking {
61 text: String,
62 signature: Option<String>,
63 },
64 ToolUse(LanguageModelToolUse),
65 StartMessage {
66 message_id: String,
67 },
68 UsageUpdate(TokenUsage),
69}
70
/// Errors that can occur while streaming a completion.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The provider rate-limited the request; retry after the given duration.
    #[error("rate limit exceeded, retry after {0:?}")]
    RateLimit(Duration),
    /// The model produced tool-use input that failed to parse as JSON.
    #[error("received bad input JSON")]
    BadInputJson {
        /// Id of the offending tool use.
        id: LanguageModelToolUseId,
        /// Name of the tool whose input failed to parse.
        tool_name: Arc<str>,
        /// The raw input string exactly as received from the model.
        raw_input: Arc<str>,
        /// The JSON parser's error message.
        json_parse_error: String,
    },
    /// Any other failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
85
86/// Indicates the format used to define the input schema for a language model tool.
87#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
88pub enum LanguageModelToolSchemaFormat {
89 /// A JSON schema, see https://json-schema.org
90 JsonSchema,
91 /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
92 JsonSchemaSubset,
93}
94
/// The reason a model stopped generating output.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a completion.
    Refusal,
}
103
/// Token-usage counters reported by a provider for a request.
///
/// All fields default to 0 and are omitted from serialization when 0.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the request input (prompt).
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens produced in the model's output.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
115
116impl TokenUsage {
117 pub fn total_tokens(&self) -> u64 {
118 self.input_tokens
119 + self.output_tokens
120 + self.cache_read_input_tokens
121 + self.cache_creation_input_tokens
122 }
123}
124
125impl Add<TokenUsage> for TokenUsage {
126 type Output = Self;
127
128 fn add(self, other: Self) -> Self {
129 Self {
130 input_tokens: self.input_tokens + other.input_tokens,
131 output_tokens: self.output_tokens + other.output_tokens,
132 cache_creation_input_tokens: self.cache_creation_input_tokens
133 + other.cache_creation_input_tokens,
134 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
135 }
136 }
137}
138
139impl Sub<TokenUsage> for TokenUsage {
140 type Output = Self;
141
142 fn sub(self, other: Self) -> Self {
143 Self {
144 input_tokens: self.input_tokens - other.input_tokens,
145 output_tokens: self.output_tokens - other.output_tokens,
146 cache_creation_input_tokens: self.cache_creation_input_tokens
147 - other.cache_creation_input_tokens,
148 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
149 }
150 }
151}
152
/// Opaque, provider-assigned identifier for a single tool use.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
155
impl fmt::Display for LanguageModelToolUseId {
    // Displays the wrapped `Arc<str>` verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
161
// Allows building a tool-use id from anything convertible to `Arc<str>`
// (e.g. `&str` or `String`).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
170
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Provider-assigned id for this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool to invoke.
    pub name: Arc<str>,
    /// The raw input string as received from the provider.
    pub raw_input: String,
    /// The tool input parsed as JSON.
    pub input: serde_json::Value,
    /// Whether `input` is complete, or may still be streaming in.
    pub is_input_complete: bool,
}
179
/// A text-only view over a completion stream: yields text chunks while
/// recording token usage as a side effect.
pub struct LanguageModelTextStream {
    /// Provider-assigned id of the message, if one was reported.
    pub message_id: Option<String>,
    /// Stream of text chunks; non-text events are filtered out upstream.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Latest usage seen so far. Has complete token usage after the stream
    /// has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
186
187impl Default for LanguageModelTextStream {
188 fn default() -> Self {
189 Self {
190 message_id: None,
191 stream: Box::pin(futures::stream::empty()),
192 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
193 }
194 }
195}
196
/// A language model that can count tokens and stream completions.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Identifier for this model used in telemetry.
    fn telemetry_id(&self) -> String;

    /// The API key to use for this model, if any. Defaults to `None`.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "max mode" (also called "burn
    /// mode"). Defaults to `false`.
    fn supports_max_mode(&self) -> bool {
        false
    }

    /// The schema format this model expects for tool input definitions.
    /// Defaults to plain JSON Schema.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Maximum number of tokens in this model's context window.
    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, if the model imposes one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens the given request would consume.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events (text, thinking, tool use, usage updates, …)
    /// for the given request.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Convenience wrapper around [`LanguageModel::stream_completion`] that
    /// reduces the event stream to a plain text stream, capturing the message
    /// id (if the first event reports one) and the latest token usage.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: it may carry the message id, or be the
            // first text chunk (which must be stashed and re-emitted below so
            // it isn't lost).
            // NOTE(review): an `Err` arriving as the very first event is
            // silently dropped by the `_` arm rather than surfaced to the
            // caller — confirm this is intentional.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id.clone());
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-emit the stashed first text chunk (if any), then keep only
            // text chunks and errors from the remaining events, recording
            // usage updates into `last_token_usage` as a side effect.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Provider-specific prompt-cache configuration, if this model supports
    /// caching. Defaults to `None`.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
317
/// Well-known error conditions shared across providers.
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    /// The request would exceed the model's context window.
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: u64 },
}
323
/// A tool that can be invoked by a language model. The implementing type is
/// the deserialized form of the tool's input; its JSON schema comes from the
/// `JsonSchema` derive.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// The tool's name as presented to the model.
    fn name() -> String;
    /// Description of what the tool does.
    fn description() -> String;
}
328
329/// An error that occurred when trying to authenticate the language model provider.
330#[derive(Debug, Error)]
331pub enum AuthenticateError {
332 #[error("credentials not found")]
333 CredentialsNotFound,
334 #[error(transparent)]
335 Other(#[from] anyhow::Error),
336}
337
/// A source of language models, e.g. an API vendor or a local runtime.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI. Defaults to the Zed
    /// assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if one is available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A subset of models to surface as recommendations. Defaults to none.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate with the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds this provider's configuration UI.
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    /// Whether the user must accept the provider's terms of service before
    /// use. Defaults to `false`.
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    /// Renders the terms-of-service acceptance UI for the given view
    /// placement, or `None` if there is nothing to show. Defaults to `None`.
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
365
/// Where the terms-of-service acceptance UI is being rendered.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    // NOTE(review): "Threadt" is a typo for "Thread", but this variant is
    // public API — renaming would break downstream `match`es.
    ThreadtEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// In the prompt-editor popup.
    PromptEditorPopup,
    /// In the provider configuration view.
    Configuration,
}
375
376pub trait LanguageModelProviderState: 'static {
377 type ObservableEntity;
378
379 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
380
381 fn subscribe<T: 'static>(
382 &self,
383 cx: &mut gpui::Context<T>,
384 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
385 ) -> Option<gpui::Subscription> {
386 let entity = self.observable_entity()?;
387 Some(cx.observe(&entity, move |this, _, cx| {
388 callback(this, cx);
389 }))
390 }
391}
392
/// Identifier for a model within its provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
395
/// Human-readable display name of a model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
398
/// Identifier for a language model provider (e.g. [`ZED_CLOUD_PROVIDER_ID`]).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
401
/// Human-readable display name of a provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
404
impl fmt::Display for LanguageModelProviderId {
    // Displays the wrapped `SharedString` verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
410
411impl From<String> for LanguageModelId {
412 fn from(value: String) -> Self {
413 Self(SharedString::from(value))
414 }
415}
416
417impl From<String> for LanguageModelName {
418 fn from(value: String) -> Self {
419 Self(SharedString::from(value))
420 }
421}
422
423impl From<String> for LanguageModelProviderId {
424 fn from(value: String) -> Self {
425 Self(SharedString::from(value))
426 }
427}
428
429impl From<String> for LanguageModelProviderName {
430 fn from(value: String) -> Self {
431 Self(SharedString::from(value))
432 }
433}