1mod model;
2pub mod provider;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7pub mod settings;
8
9use anyhow::Result;
10use client::Client;
11use futures::{future::BoxFuture, stream::BoxStream};
12use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
13pub use model::*;
14use project::Fs;
15pub(crate) use rate_limiter::*;
16pub use registry::*;
17pub use request::*;
18pub use role::*;
19use schemars::JsonSchema;
20use serde::de::DeserializeOwned;
21use std::{future::Future, sync::Arc};
22
23pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
24 settings::init(fs, cx);
25 registry::init(client, cx);
26}
27
/// A language model that can count tokens, stream completions, and run tool calls.
///
/// Implementations must be `Send + Sync`; models are handed around as
/// `Arc<dyn LanguageModel>` (see `LanguageModelProvider::provided_models`).
pub trait LanguageModel: Send + Sync {
    /// Returns this model's identifier.
    fn id(&self) -> LanguageModelId;
    /// Returns this model's name.
    fn name(&self) -> LanguageModelName;
    /// Returns the identifier of the provider serving this model.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Returns the name of the provider serving this model.
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Returns the string used to identify this model in telemetry.
    fn telemetry_id(&self) -> String;

    /// Returns the maximum number of tokens for this model
    /// (presumably its context window size — confirm per implementation).
    fn max_token_count(&self) -> usize;

    /// Counts the tokens in `request`.
    ///
    /// The returned future is `'static`, so it does not borrow `self` or `cx`.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;

    /// Requests a completion for `request`.
    ///
    /// The outer future resolves to a stream once the request is started; each
    /// stream item is a `String` piece of the completion or an error.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    /// Asks the model to use the tool `name` (described by `description`) whose
    /// input is described by the JSON `schema`, resolving to the JSON value produced.
    ///
    /// `use_tool` on `dyn LanguageModel` is a typed wrapper that derives `schema`
    /// from a Rust type via `schemars` and deserializes the result.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        name: String,
        description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>>;
}
58
59impl dyn LanguageModel {
60 pub fn use_tool<T: LanguageModelTool>(
61 &self,
62 request: LanguageModelRequest,
63 cx: &AsyncAppContext,
64 ) -> impl 'static + Future<Output = Result<T>> {
65 let schema = schemars::schema_for!(T);
66 let schema_json = serde_json::to_value(&schema).unwrap();
67 let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
68 async move {
69 let response = request.await?;
70 Ok(serde_json::from_value(response)?)
71 }
72 }
73}
74
/// A Rust type that can be requested from a model as the result of a tool call.
///
/// `use_tool` on `dyn LanguageModel` sends this type's JSON schema (from
/// [`JsonSchema`]) along with [`name`](Self::name) and
/// [`description`](Self::description), then deserializes the model's JSON
/// response into the type (hence [`DeserializeOwned`]).
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// The tool name sent to the model.
    fn name() -> String;
    /// The tool description sent to the model.
    fn description() -> String;
}
79
/// A source of language models, along with its authentication flow.
pub trait LanguageModelProvider: 'static {
    /// Returns this provider's identifier.
    fn id(&self) -> LanguageModelProviderId;
    /// Returns this provider's name.
    fn name(&self) -> LanguageModelProviderName;
    /// Returns the models this provider currently offers.
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
    /// Hook for preparing `_model`; the default implementation does nothing.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
    /// Returns whether this provider is currently authenticated.
    fn is_authenticated(&self, cx: &AppContext) -> bool;
    /// Starts authentication, returning a task that resolves when it finishes.
    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
    /// Builds a view that prompts the user to authenticate with this provider.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    /// Clears stored credentials, returning a task that resolves when done.
    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
90
91pub trait LanguageModelProviderState: 'static {
92 type ObservableEntity;
93
94 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
95
96 fn subscribe<T: 'static>(
97 &self,
98 cx: &mut gpui::ModelContext<T>,
99 callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
100 ) -> Option<gpui::Subscription> {
101 let entity = self.observable_entity()?;
102 Some(cx.observe(&entity, move |this, _, cx| {
103 callback(this, cx);
104 }))
105 }
106}
107
/// Identifier of a language model, as returned by `LanguageModel::id`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelId(pub SharedString);
110
/// Name of a language model, as returned by `LanguageModel::name`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
113
/// Identifier of a model provider, as returned by `LanguageModelProvider::id`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
116
/// Name of a model provider, as returned by `LanguageModelProvider::name`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
119
120impl From<String> for LanguageModelId {
121 fn from(value: String) -> Self {
122 Self(SharedString::from(value))
123 }
124}
125
126impl From<String> for LanguageModelName {
127 fn from(value: String) -> Self {
128 Self(SharedString::from(value))
129 }
130}
131
132impl From<String> for LanguageModelProviderId {
133 fn from(value: String) -> Self {
134 Self(SharedString::from(value))
135 }
136}
137
138impl From<String> for LanguageModelProviderName {
139 fn from(value: String) -> Self {
140 Self(SharedString::from(value))
141 }
142}