language_model.rs

 1mod model;
 2pub mod provider;
 3mod registry;
 4mod request;
 5mod role;
 6pub mod settings;
 7
 8use std::sync::Arc;
 9
10use anyhow::Result;
11use client::Client;
12use futures::{future::BoxFuture, stream::BoxStream};
13use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
14
15pub use model::*;
16pub use registry::*;
17pub use request::*;
18pub use role::*;
19
20pub fn init(client: Arc<Client>, cx: &mut AppContext) {
21    settings::init(cx);
22    registry::init(client, cx);
23}
24
25pub trait LanguageModel: Send + Sync {
26    fn id(&self) -> LanguageModelId;
27    fn name(&self) -> LanguageModelName;
28    fn provider_name(&self) -> LanguageModelProviderName;
29    fn telemetry_id(&self) -> String;
30
31    fn max_token_count(&self) -> usize;
32
33    fn count_tokens(
34        &self,
35        request: LanguageModelRequest,
36        cx: &AppContext,
37    ) -> BoxFuture<'static, Result<usize>>;
38
39    fn stream_completion(
40        &self,
41        request: LanguageModelRequest,
42        cx: &AsyncAppContext,
43    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
44}
45
46pub trait LanguageModelProvider: 'static {
47    fn name(&self) -> LanguageModelProviderName;
48    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
49    fn is_authenticated(&self, cx: &AppContext) -> bool;
50    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
51    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
52    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
53}
54
55pub trait LanguageModelProviderState: 'static {
56    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
57}
58
59#[derive(Clone, Eq, PartialEq, Hash, Debug)]
60pub struct LanguageModelId(pub SharedString);
61
62#[derive(Clone, Eq, PartialEq, Hash, Debug)]
63pub struct LanguageModelName(pub SharedString);
64
65#[derive(Clone, Eq, PartialEq, Hash, Debug)]
66pub struct LanguageModelProviderName(pub SharedString);
67
68impl From<String> for LanguageModelId {
69    fn from(value: String) -> Self {
70        Self(SharedString::from(value))
71    }
72}
73
74impl From<String> for LanguageModelName {
75    fn from(value: String) -> Self {
76        Self(SharedString::from(value))
77    }
78}
79
80impl From<String> for LanguageModelProviderName {
81    fn from(value: String) -> Self {
82        Self(SharedString::from(value))
83    }
84}