1package openaicompat
2
3import (
4 "encoding/json"
5
6 "github.com/charmbracelet/fantasy/ai"
7 "github.com/charmbracelet/fantasy/openai"
8 "github.com/openai/openai-go/v2/option"
9)
10
11type options struct {
12 openaiOptions []openai.Option
13 languageModelOptions []openai.LanguageModelOption
14}
15
// Name is the default provider name that New registers via openai.WithName.
const Name = "openai-compat"
19
20type Option = func(*options)
21
22func New(url string, opts ...Option) ai.Provider {
23 providerOptions := options{
24 openaiOptions: []openai.Option{
25 openai.WithName(Name),
26 openai.WithBaseURL(url),
27 },
28 languageModelOptions: []openai.LanguageModelOption{
29 openai.WithLanguageModelPrepareCallFunc(languagePrepareModelCall),
30 // openai.WithLanguageModelStreamExtraFunc(languageModelStreamExtra),
31 // openai.WithLanguageModelExtraContentFunc(languageModelExtraContent),
32 },
33 }
34 for _, o := range opts {
35 o(&providerOptions)
36 }
37
38 providerOptions.openaiOptions = append(providerOptions.openaiOptions, openai.WithLanguageModelOptions(providerOptions.languageModelOptions...))
39 return openai.New(providerOptions.openaiOptions...)
40}
41
42func WithAPIKey(apiKey string) Option {
43 return func(o *options) {
44 o.openaiOptions = append(o.openaiOptions, openai.WithAPIKey(apiKey))
45 }
46}
47
48func WithName(name string) Option {
49 return func(o *options) {
50 o.openaiOptions = append(o.openaiOptions, openai.WithName(name))
51 }
52}
53
54func WithHeaders(headers map[string]string) Option {
55 return func(o *options) {
56 o.openaiOptions = append(o.openaiOptions, openai.WithHeaders(headers))
57 }
58}
59
60func WithHTTPClient(client option.HTTPClient) Option {
61 return func(o *options) {
62 o.openaiOptions = append(o.openaiOptions, openai.WithHTTPClient(client))
63 }
64}
65
66func WithLanguageUniqueToolCallIds() Option {
67 return func(l *options) {
68 l.languageModelOptions = append(l.languageModelOptions, openai.WithLanguageUniqueToolCallIds())
69 }
70}
71
72func WithLanguageModelGenerateIDFunc(fn openai.LanguageModelGenerateIDFunc) Option {
73 return func(l *options) {
74 l.languageModelOptions = append(l.languageModelOptions, openai.WithLanguageModelGenerateIDFunc(fn))
75 }
76}
77
// structToMapJSON converts s to a generic map by round-tripping it through
// encoding/json, so struct tags (names, omitempty) are honored and JSON
// numbers come back as float64.
//
// It returns an error if s cannot be marshaled, or if it does not encode to
// a JSON object (e.g. a slice or scalar). A value that encodes to JSON null
// (such as a nil pointer or nil) yields a nil map and a nil error.
func structToMapJSON(s any) (map[string]any, error) {
	raw, err := json.Marshal(s)
	if err != nil {
		return nil, err
	}

	var out map[string]any
	if err := json.Unmarshal(raw, &out); err != nil {
		return nil, err
	}
	return out, nil
}