1mod model;
2pub mod provider;
3mod registry;
4mod request;
5mod role;
6pub mod settings;
7
8use std::sync::Arc;
9
10use anyhow::Result;
11use client::Client;
12use futures::{future::BoxFuture, stream::BoxStream};
13use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
14
15pub use model::*;
16pub use registry::*;
17pub use request::*;
18pub use role::*;
19use schemars::JsonSchema;
20use serde::de::DeserializeOwned;
21
/// Initializes this crate's submodules: `settings::init` runs first, then
/// `registry::init`, which receives the shared [`Client`].
// NOTE(review): settings are initialized before the registry; presumably the
// registry reads them during its own init, so keep this order — confirm
// before reordering.
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    settings::init(cx);
    registry::init(client, cx);
}
26
/// A single language model that can count tokens, stream completions, and be
/// asked to use a tool.
///
/// Implementors must be `Send + Sync`; providers hand models out as
/// `Arc<dyn LanguageModel>`.
pub trait LanguageModel: Send + Sync {
    /// Identifier of this model.
    fn id(&self) -> LanguageModelId;
    /// Name of this model.
    fn name(&self) -> LanguageModelName;
    /// Identifier of the provider that supplies this model.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Name of the provider that supplies this model.
    fn provider_name(&self) -> LanguageModelProviderName;
    // NOTE(review): the string format is implementation-defined — confirm
    // with implementors before parsing it anywhere.
    /// Identifier reported for this model in telemetry.
    fn telemetry_id(&self) -> String;

    // NOTE(review): presumably the model's context-window size in tokens —
    // confirm per provider whether it covers prompt + completion.
    /// Maximum token count for this model.
    fn max_token_count(&self) -> usize;

    /// Counts the tokens in `request`.
    ///
    /// The returned future is `'static`, so it cannot borrow `self` or `cx`;
    /// anything it needs must be captured by value before it is boxed.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;

    /// Requests a completion for `request`.
    ///
    /// The outer future fails if the request cannot be started; on success it
    /// yields a stream of `String` chunks, each of which can fail on its own.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    // NOTE(review): presumably the returned value is the tool-call input the
    // model produced, expected to conform to `schema` — confirm in
    // implementations.
    /// Asks the model to use the tool `name`, described by `description`,
    /// whose input is described by the JSON `schema`; resolves to a
    /// `serde_json::Value`.
    fn use_tool(
        &self,
        request: LanguageModelRequest,
        name: String,
        description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>>;
}
57
/// A tool described by a Rust type.
///
/// The implementing type must be `'static`, implement [`JsonSchema`] (so a
/// schema can be generated for it), and implement [`DeserializeOwned`].
// NOTE(review): presumably the type is deserialized from the value returned by
// `LanguageModel::use_tool`, with `name()`/`description()` passed as that
// method's `name`/`description` — confirm at call sites.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// Name of the tool.
    fn name() -> String;
    /// Description of the tool.
    fn description() -> String;
}
62
/// A source of [`LanguageModel`]s with its own authentication state.
///
/// Implementors must be `'static`.
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Returns the models this provider offers, given the app state in `cx`.
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
    /// Hook for loading `model`. The default implementation does nothing.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
    /// Whether this provider is currently authenticated.
    fn is_authenticated(&self, cx: &AppContext) -> bool;
    /// Starts authenticating; the returned task resolves to the outcome.
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    /// Builds the view shown for this provider's authentication prompt.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    // NOTE(review): presumably clears stored credentials — confirm per
    // implementation what "reset" removes.
    /// Resets this provider's credentials; the returned task resolves to the
    /// outcome.
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
}
73
/// Lets an owner of a `gpui::ModelContext<T>` subscribe to a provider's state.
///
/// Implementors must be `'static`.
pub trait LanguageModelProviderState: 'static {
    // NOTE(review): presumably the subscription notifies the model owning `cx`
    // when this state changes, and `None` means there is nothing to observe —
    // confirm in implementations.
    /// Creates a subscription in `cx`, returning `None` if none was created.
    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
}
77
78#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
79pub struct LanguageModelId(pub SharedString);
80
81#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
82pub struct LanguageModelName(pub SharedString);
83
84#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
85pub struct LanguageModelProviderId(pub SharedString);
86
87#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
88pub struct LanguageModelProviderName(pub SharedString);
89
90impl From<String> for LanguageModelId {
91 fn from(value: String) -> Self {
92 Self(SharedString::from(value))
93 }
94}
95
96impl From<String> for LanguageModelName {
97 fn from(value: String) -> Self {
98 Self(SharedString::from(value))
99 }
100}
101
102impl From<String> for LanguageModelProviderId {
103 fn from(value: String) -> Self {
104 Self(SharedString::from(value))
105 }
106}
107
108impl From<String> for LanguageModelProviderName {
109 fn from(value: String) -> Self {
110 Self(SharedString::from(value))
111 }
112}