1mod model;
2pub mod provider;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7pub mod settings;
8
9use anyhow::Result;
10use client::Client;
11use futures::{future::BoxFuture, stream::BoxStream};
12use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
13pub use model::*;
14use project::Fs;
15pub(crate) use rate_limiter::*;
16pub use registry::*;
17pub use request::*;
18pub use role::*;
19use schemars::JsonSchema;
20use serde::de::DeserializeOwned;
21use std::{future::Future, sync::Arc};
22
23pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
24 settings::init(fs, cx);
25 registry::init(client, cx);
26}
27
/// A single language model that can count tokens, stream completions, and invoke tools.
///
/// Implementors must be `Send + Sync`. Models are handed around as `Arc<dyn LanguageModel>`
/// (see `LanguageModelProvider::provided_models` in this file).
pub trait LanguageModel: Send + Sync {
    /// Identifier of this model.
    fn id(&self) -> LanguageModelId;
    /// Name of this model (presumably the human-readable one — confirm with implementors).
    fn name(&self) -> LanguageModelName;
    /// Identifier of the provider this model belongs to.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Name of the provider this model belongs to.
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Identifier reported to telemetry. NOTE(review): format is implementation-defined.
    fn telemetry_id(&self) -> String;

    /// Maximum token count for this model.
    /// NOTE(review): presumably the context-window size — confirm with implementors.
    fn max_token_count(&self) -> usize;

    /// Counts the tokens in `request`.
    ///
    /// The returned future is `'static`, so it cannot borrow `self` or `cx`; anything it
    /// needs must be captured by value before returning.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;

    /// Starts a completion for `request`.
    ///
    /// The outer future resolves to a stream once the completion is available; each stream
    /// item is either a `String` chunk or an error. Both are `'static`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    /// Asks the model to use the tool called `name`, described by `description`, whose
    /// input is described by the JSON `schema`, and resolves to the model's raw JSON result.
    ///
    /// The typed wrapper `use_tool` on `dyn LanguageModel` (in this file) builds `schema`
    /// from a [`LanguageModelTool`] type and deserializes the returned value into it.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        name: String,
        description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>>;
}
58
59impl dyn LanguageModel {
60 pub fn use_tool<T: LanguageModelTool>(
61 &self,
62 request: LanguageModelRequest,
63 cx: &AsyncAppContext,
64 ) -> impl 'static + Future<Output = Result<T>> {
65 let schema = schemars::schema_for!(T);
66 let schema_json = serde_json::to_value(&schema).unwrap();
67 let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
68 async move {
69 let response = request.await?;
70 Ok(serde_json::from_value(response)?)
71 }
72 }
73}
74
/// A typed tool that a [`LanguageModel`] can be asked to invoke through `use_tool`.
///
/// `use_tool` derives the JSON schema sent to the model from the [`JsonSchema`] impl, and
/// deserializes the model's JSON result back into `Self` via [`DeserializeOwned`].
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// Tool name passed to the model.
    fn name() -> String;
    /// Tool description passed to the model.
    fn description() -> String;
}
79
/// A source of [`LanguageModel`]s, together with the operations needed to authenticate
/// against it.
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Returns the models this provider offers, as shared trait objects.
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
    /// Hook for loading `model`. The default implementation does nothing.
    /// NOTE(review): when callers invoke this is not visible from this file.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
    /// Whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &AppContext) -> bool;
    /// Starts authentication; the returned task resolves to `Ok(())` or an error.
    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
    /// Returns a view for prompting authentication (presumably shown when
    /// `is_authenticated` is false — confirm with callers).
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    /// Resets stored credentials; the returned task resolves to `Ok(())` or an error.
    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
90
/// State of a provider that other models/entities can subscribe to.
pub trait LanguageModelProviderState: 'static {
    /// Subscribes the entity owning `cx` to this state.
    ///
    /// NOTE(review): `None` presumably means there is nothing to observe — confirm with
    /// implementors.
    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
}
94
95#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
96pub struct LanguageModelId(pub SharedString);
97
98#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
99pub struct LanguageModelName(pub SharedString);
100
101#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
102pub struct LanguageModelProviderId(pub SharedString);
103
104#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
105pub struct LanguageModelProviderName(pub SharedString);
106
107impl From<String> for LanguageModelId {
108 fn from(value: String) -> Self {
109 Self(SharedString::from(value))
110 }
111}
112
113impl From<String> for LanguageModelName {
114 fn from(value: String) -> Self {
115 Self(SharedString::from(value))
116 }
117}
118
119impl From<String> for LanguageModelProviderId {
120 fn from(value: String) -> Self {
121 Self(SharedString::from(value))
122 }
123}
124
125impl From<String> for LanguageModelProviderName {
126 fn from(value: String) -> Self {
127 Self(SharedString::from(value))
128 }
129}