1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7
8#[cfg(any(test, feature = "test-support"))]
9pub mod fake_provider;
10
11use anyhow::{Context as _, Result};
12use client::Client;
13use futures::FutureExt;
14use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
15use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
16use http_client::http::{HeaderMap, HeaderValue};
17use icons::IconName;
18use parking_lot::Mutex;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize, de::DeserializeOwned};
21use std::fmt;
22use std::ops::{Add, Sub};
23use std::str::FromStr as _;
24use std::sync::Arc;
25use thiserror::Error;
26use util::serde::is_default;
27use zed_llm_client::{
28 CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
29 MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
30};
31
32pub use crate::model::*;
33pub use crate::rate_limiter::*;
34pub use crate::registry::*;
35pub use crate::request::*;
36pub use crate::role::*;
37pub use crate::telemetry::*;
38
/// Provider id for the Zed-hosted cloud provider.
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
40
41pub fn init(client: Arc<Client>, cx: &mut App) {
42 init_settings(cx);
43 RefreshLlmTokenListener::register(client.clone(), cx);
44}
45
/// Initializes the model registry only; usable without a `Client`
/// (unlike [`init`], which also registers the token listener).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
49
/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    // Maximum number of cache anchor points the provider allows.
    pub max_cache_anchors: usize,
    // Whether the provider should speculate about cache placement.
    // NOTE(review): semantics inferred from the name — confirm against the
    // provider implementations that read this.
    pub should_speculate: bool,
    // Minimum total token count before caching is worthwhile.
    pub min_total_token: usize,
}
57
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A status update about the request itself.
    StatusUpdate(CompletionRequestStatus),
    /// The stream finished, with the reason generation stopped.
    Stop(StopReason),
    /// A chunk of completion text.
    Text(String),
    /// A chunk of the model's "thinking" output.
    Thinking {
        text: String,
        /// Provider-specific signature for the thinking block, if any.
        signature: Option<String>,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// A new message began, carrying its provider-assigned id.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage counts for the request.
    UsageUpdate(TokenUsage),
}
74
75#[derive(Error, Debug)]
76pub enum LanguageModelCompletionError {
77 #[error("received bad input JSON")]
78 BadInputJson {
79 id: LanguageModelToolUseId,
80 tool_name: Arc<str>,
81 raw_input: Arc<str>,
82 json_parse_error: String,
83 },
84 #[error(transparent)]
85 Other(#[from] anyhow::Error),
86}
87
/// Indicates the format used to define the input schema for a language model tool.
///
/// Providers that only accept a restricted schema dialect (e.g. Google AI)
/// use `JsonSchemaSubset`; everything else uses full `JsonSchema`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}
96
/// Why the model stopped generating a completion.
///
/// Serialized in `snake_case` (e.g. `end_turn`, `max_tokens`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
105
/// Model-request usage reported by the LLM service via response headers.
#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
    // The usage limit in effect for this account/plan.
    pub limit: UsageLimit,
    // How much of the limit has been consumed.
    pub amount: i32,
}
111
112impl RequestUsage {
113 pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
114 let limit = headers
115 .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
116 .with_context(|| {
117 format!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header")
118 })?;
119 let limit = UsageLimit::from_str(limit.to_str()?)?;
120
121 let amount = headers
122 .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
123 .with_context(|| {
124 format!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header")
125 })?;
126 let amount = amount.to_str()?.parse::<i32>()?;
127
128 Ok(Self { limit, amount })
129 }
130}
131
/// Token usage counts for a single request, as reported by the provider.
///
/// All fields default to zero and are omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Tokens consumed by the prompt/input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u32,
    // Tokens produced in the completion/output.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u32,
    // Input tokens spent creating a prompt-cache entry.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u32,
    // Input tokens served from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u32,
}
143
144impl TokenUsage {
145 pub fn total_tokens(&self) -> u32 {
146 self.input_tokens
147 + self.output_tokens
148 + self.cache_read_input_tokens
149 + self.cache_creation_input_tokens
150 }
151}
152
153impl Add<TokenUsage> for TokenUsage {
154 type Output = Self;
155
156 fn add(self, other: Self) -> Self {
157 Self {
158 input_tokens: self.input_tokens + other.input_tokens,
159 output_tokens: self.output_tokens + other.output_tokens,
160 cache_creation_input_tokens: self.cache_creation_input_tokens
161 + other.cache_creation_input_tokens,
162 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
163 }
164 }
165}
166
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    /// Field-wise subtraction of two usage records.
    ///
    /// NOTE(review): each field underflows (panicking in debug builds,
    /// wrapping in release) if `other`'s value exceeds `self`'s — confirm
    /// callers only ever subtract an earlier snapshot from a later one.
    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}
180
/// Identifier for a single tool use emitted by a model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
183
184impl fmt::Display for LanguageModelToolUseId {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 write!(f, "{}", self.0)
187 }
188}
189
// Blanket conversion: anything convertible into `Arc<str>` (e.g. `&str`,
// `String`) can be used as a tool-use id.
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
198
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    // Provider-assigned id for this tool use.
    pub id: LanguageModelToolUseId,
    // Name of the tool being invoked.
    pub name: Arc<str>,
    // The raw JSON input string (may be partial while streaming — see
    // `is_input_complete`).
    pub raw_input: String,
    // The parsed input value.
    pub input: serde_json::Value,
    // Whether `raw_input` has been fully received.
    pub is_input_complete: bool,
}
207
/// A text-only view of a completion stream.
pub struct LanguageModelTextStream {
    // Provider-assigned message id, when the stream announced one.
    pub message_id: Option<String>,
    // The stream of text chunks (non-text events are filtered out).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
214
215impl Default for LanguageModelTextStream {
216 fn default() -> Self {
217 Self {
218 message_id: None,
219 stream: Box::pin(futures::stream::empty()),
220 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
221 }
222 }
223}
224
225pub trait LanguageModel: Send + Sync {
226 fn id(&self) -> LanguageModelId;
227 fn name(&self) -> LanguageModelName;
228 fn provider_id(&self) -> LanguageModelProviderId;
229 fn provider_name(&self) -> LanguageModelProviderName;
230 fn telemetry_id(&self) -> String;
231
232 fn api_key(&self, _cx: &App) -> Option<String> {
233 None
234 }
235
236 /// Whether this model supports images
237 fn supports_images(&self) -> bool;
238
239 /// Whether this model supports tools.
240 fn supports_tools(&self) -> bool;
241
242 /// Whether this model supports choosing which tool to use.
243 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
244
245 /// Returns whether this model supports "burn mode";
246 fn supports_max_mode(&self) -> bool {
247 false
248 }
249
250 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
251 LanguageModelToolSchemaFormat::JsonSchema
252 }
253
254 fn max_token_count(&self) -> usize;
255 fn max_output_tokens(&self) -> Option<u32> {
256 None
257 }
258
259 fn count_tokens(
260 &self,
261 request: LanguageModelRequest,
262 cx: &App,
263 ) -> BoxFuture<'static, Result<usize>>;
264
265 fn stream_completion(
266 &self,
267 request: LanguageModelRequest,
268 cx: &AsyncApp,
269 ) -> BoxFuture<
270 'static,
271 Result<
272 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
273 >,
274 >;
275
276 fn stream_completion_text(
277 &self,
278 request: LanguageModelRequest,
279 cx: &AsyncApp,
280 ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
281 let future = self.stream_completion(request, cx);
282
283 async move {
284 let events = future.await?;
285 let mut events = events.fuse();
286 let mut message_id = None;
287 let mut first_item_text = None;
288 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
289
290 if let Some(first_event) = events.next().await {
291 match first_event {
292 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
293 message_id = Some(id.clone());
294 }
295 Ok(LanguageModelCompletionEvent::Text(text)) => {
296 first_item_text = Some(text);
297 }
298 _ => (),
299 }
300 }
301
302 let stream = futures::stream::iter(first_item_text.map(Ok))
303 .chain(events.filter_map({
304 let last_token_usage = last_token_usage.clone();
305 move |result| {
306 let last_token_usage = last_token_usage.clone();
307 async move {
308 match result {
309 Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
310 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
311 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
312 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
313 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
314 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
315 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
316 *last_token_usage.lock() = token_usage;
317 None
318 }
319 Err(err) => Some(Err(err)),
320 }
321 }
322 }
323 }))
324 .boxed();
325
326 Ok(LanguageModelTextStream {
327 message_id,
328 stream,
329 last_token_usage,
330 })
331 }
332 .boxed()
333 }
334
335 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
336 None
337 }
338
339 #[cfg(any(test, feature = "test-support"))]
340 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
341 unimplemented!()
342 }
343}
344
/// Well-known error conditions that can be reported uniformly across providers.
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    /// The request's token count exceeded the model's context window.
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}
350
/// A tool that can be exposed to a language model.
///
/// The implementing type doubles as the tool's input schema (via `JsonSchema`)
/// and as the deserialized invocation arguments (via `DeserializeOwned`).
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// The tool's name, as presented to the model.
    fn name() -> String;
    /// A description of the tool, for the model's benefit.
    fn description() -> String;
}
355
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// No stored credentials were found for this provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
364
/// A provider of language models (typically one API vendor).
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A default model suited to fast/cheap requests, if one is available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A curated subset of models to recommend; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Optional hook to eagerly prepare a model; no-op by default.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider (e.g. credentials).
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    /// Whether the user still needs to accept the provider's terms of service.
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    /// Renders the terms-of-service acceptance UI; `None` by default.
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
393
/// Where a provider's terms-of-service acceptance UI is being rendered.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    // NOTE(review): variant name contains a typo ("Threadt"); renaming would
    // break callers, so it is left as-is. Also confirm this doc and the one
    // on `ThreadFreshStart` are not swapped — "empty state" usually means no
    // past interactions.
    ThreadtEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}
403
404pub trait LanguageModelProviderState: 'static {
405 type ObservableEntity;
406
407 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
408
409 fn subscribe<T: 'static>(
410 &self,
411 cx: &mut gpui::Context<T>,
412 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
413 ) -> Option<gpui::Subscription> {
414 let entity = self.observable_entity()?;
415 Some(cx.observe(&entity, move |this, _, cx| {
416 callback(this, cx);
417 }))
418 }
419}
420
/// Identifier for a language model, unique within its provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
423
/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
426
/// Identifier for a language model provider (e.g. [`ZED_CLOUD_PROVIDER_ID`]).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
429
/// Human-readable name of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
432
433impl fmt::Display for LanguageModelProviderId {
434 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 write!(f, "{}", self.0)
436 }
437}
438
439impl From<String> for LanguageModelId {
440 fn from(value: String) -> Self {
441 Self(SharedString::from(value))
442 }
443}
444
445impl From<String> for LanguageModelName {
446 fn from(value: String) -> Self {
447 Self(SharedString::from(value))
448 }
449}
450
451impl From<String> for LanguageModelProviderId {
452 fn from(value: String) -> Self {
453 Self(SharedString::from(value))
454 }
455}
456
457impl From<String> for LanguageModelProviderName {
458 fn from(value: String) -> Self {
459 Self(SharedString::from(value))
460 }
461}