1mod model;
2pub mod provider;
3mod registry;
4mod request;
5mod role;
6pub mod settings;
7
8use std::sync::Arc;
9
10use anyhow::Result;
11use client::Client;
12use futures::{future::BoxFuture, stream::BoxStream};
13use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
14
15pub use model::*;
16pub use registry::*;
17pub use request::*;
18pub use role::*;
19
20pub fn init(client: Arc<Client>, cx: &mut AppContext) {
21 settings::init(cx);
22 registry::init(client, cx);
23}
24
25pub trait LanguageModel: Send + Sync {
26 fn id(&self) -> LanguageModelId;
27 fn name(&self) -> LanguageModelName;
28 fn provider_id(&self) -> LanguageModelProviderId;
29 fn provider_name(&self) -> LanguageModelProviderName;
30 fn telemetry_id(&self) -> String;
31
32 fn max_token_count(&self) -> usize;
33
34 fn count_tokens(
35 &self,
36 request: LanguageModelRequest,
37 cx: &AppContext,
38 ) -> BoxFuture<'static, Result<usize>>;
39
40 fn stream_completion(
41 &self,
42 request: LanguageModelRequest,
43 cx: &AsyncAppContext,
44 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
45}
46
47pub trait LanguageModelProvider: 'static {
48 fn id(&self) -> LanguageModelProviderId;
49 fn name(&self) -> LanguageModelProviderName;
50 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
51 fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
52 fn is_authenticated(&self, cx: &AppContext) -> bool;
53 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
54 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
55 fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
56}
57
58pub trait LanguageModelProviderState: 'static {
59 fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
60}
61
62#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
63pub struct LanguageModelId(pub SharedString);
64
65#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
66pub struct LanguageModelName(pub SharedString);
67
68#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
69pub struct LanguageModelProviderId(pub SharedString);
70
71#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
72pub struct LanguageModelProviderName(pub SharedString);
73
74impl From<String> for LanguageModelId {
75 fn from(value: String) -> Self {
76 Self(SharedString::from(value))
77 }
78}
79
80impl From<String> for LanguageModelName {
81 fn from(value: String) -> Self {
82 Self(SharedString::from(value))
83 }
84}
85
86impl From<String> for LanguageModelProviderId {
87 fn from(value: String) -> Self {
88 Self(SharedString::from(value))
89 }
90}
91
92impl From<String> for LanguageModelProviderName {
93 fn from(value: String) -> Self {
94 Self(SharedString::from(value))
95 }
96}