// language_model.rs

 1mod model;
 2pub mod provider;
 3mod registry;
 4mod request;
 5mod role;
 6pub mod settings;
 7
 8use std::sync::Arc;
 9
10use anyhow::Result;
11use client::Client;
12use futures::{future::BoxFuture, stream::BoxStream};
13use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
14
15pub use model::*;
16pub use registry::*;
17pub use request::*;
18pub use role::*;
19
20pub fn init(client: Arc<Client>, cx: &mut AppContext) {
21    settings::init(cx);
22    registry::init(client, cx);
23}
24
25pub trait LanguageModel: Send + Sync {
26    fn id(&self) -> LanguageModelId;
27    fn name(&self) -> LanguageModelName;
28    fn provider_id(&self) -> LanguageModelProviderId;
29    fn provider_name(&self) -> LanguageModelProviderName;
30    fn telemetry_id(&self) -> String;
31
32    fn max_token_count(&self) -> usize;
33
34    fn count_tokens(
35        &self,
36        request: LanguageModelRequest,
37        cx: &AppContext,
38    ) -> BoxFuture<'static, Result<usize>>;
39
40    fn stream_completion(
41        &self,
42        request: LanguageModelRequest,
43        cx: &AsyncAppContext,
44    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
45}
46
47pub trait LanguageModelProvider: 'static {
48    fn id(&self) -> LanguageModelProviderId;
49    fn name(&self) -> LanguageModelProviderName;
50    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
51    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
52    fn is_authenticated(&self, cx: &AppContext) -> bool;
53    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
54    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
55    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
56}
57
58pub trait LanguageModelProviderState: 'static {
59    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
60}
61
62#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
63pub struct LanguageModelId(pub SharedString);
64
65#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
66pub struct LanguageModelName(pub SharedString);
67
68#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
69pub struct LanguageModelProviderId(pub SharedString);
70
71#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
72pub struct LanguageModelProviderName(pub SharedString);
73
74impl From<String> for LanguageModelId {
75    fn from(value: String) -> Self {
76        Self(SharedString::from(value))
77    }
78}
79
80impl From<String> for LanguageModelName {
81    fn from(value: String) -> Self {
82        Self(SharedString::from(value))
83    }
84}
85
86impl From<String> for LanguageModelProviderId {
87    fn from(value: String) -> Self {
88        Self(SharedString::from(value))
89    }
90}
91
92impl From<String> for LanguageModelProviderName {
93    fn from(value: String) -> Self {
94        Self(SharedString::from(value))
95    }
96}