1package openaicompat
2
3import (
4 "github.com/charmbracelet/fantasy/ai"
5 "github.com/charmbracelet/fantasy/openai"
6 "github.com/openai/openai-go/v2/option"
7)
8
9type options struct {
10 openaiOptions []openai.Option
11 languageModelOptions []openai.LanguageModelOption
12}
13
// Name is the default provider name that New passes to openai.WithName.
const Name = "openai-compat"
17
18type Option = func(*options)
19
20func New(url string, opts ...Option) ai.Provider {
21 providerOptions := options{
22 openaiOptions: []openai.Option{
23 openai.WithName(Name),
24 openai.WithBaseURL(url),
25 },
26 languageModelOptions: []openai.LanguageModelOption{
27 openai.WithLanguageModelPrepareCallFunc(languagePrepareModelCall),
28 openai.WithLanguageModelStreamExtraFunc(languageModelStreamExtra),
29 openai.WithLanguageModelExtraContentFunc(languageModelExtraContent),
30 },
31 }
32 for _, o := range opts {
33 o(&providerOptions)
34 }
35
36 providerOptions.openaiOptions = append(providerOptions.openaiOptions, openai.WithLanguageModelOptions(providerOptions.languageModelOptions...))
37 return openai.New(providerOptions.openaiOptions...)
38}
39
40func WithAPIKey(apiKey string) Option {
41 return func(o *options) {
42 o.openaiOptions = append(o.openaiOptions, openai.WithAPIKey(apiKey))
43 }
44}
45
46func WithName(name string) Option {
47 return func(o *options) {
48 o.openaiOptions = append(o.openaiOptions, openai.WithName(name))
49 }
50}
51
52func WithHeaders(headers map[string]string) Option {
53 return func(o *options) {
54 o.openaiOptions = append(o.openaiOptions, openai.WithHeaders(headers))
55 }
56}
57
58func WithHTTPClient(client option.HTTPClient) Option {
59 return func(o *options) {
60 o.openaiOptions = append(o.openaiOptions, openai.WithHTTPClient(client))
61 }
62}
63
64func WithLanguageUniqueToolCallIds() Option {
65 return func(l *options) {
66 l.languageModelOptions = append(l.languageModelOptions, openai.WithLanguageUniqueToolCallIds())
67 }
68}
69
70func WithLanguageModelGenerateIDFunc(fn openai.LanguageModelGenerateIDFunc) Option {
71 return func(l *options) {
72 l.languageModelOptions = append(l.languageModelOptions, openai.WithLanguageModelGenerateIDFunc(fn))
73 }
74}