1package provider
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "strings"
9
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
14type bedrockOptions struct {
15 // Bedrock specific options can be added here
16}
17
18type BedrockOption func(*bedrockOptions)
19
20type bedrockClient struct {
21 providerOptions providerClientOptions
22 options bedrockOptions
23 childProvider ProviderClient
24}
25
26type BedrockClient ProviderClient
27
28func newBedrockClient(opts providerClientOptions) BedrockClient {
29 bedrockOpts := bedrockOptions{}
30 // Apply bedrock specific options if they are added in the future
31
32 // Get AWS region from environment
33 region := os.Getenv("AWS_REGION")
34 if region == "" {
35 region = os.Getenv("AWS_DEFAULT_REGION")
36 }
37
38 if region == "" {
39 region = "us-east-1" // default region
40 }
41 if len(region) < 2 {
42 return &bedrockClient{
43 providerOptions: opts,
44 options: bedrockOpts,
45 childProvider: nil, // Will cause an error when used
46 }
47 }
48
49 // Prefix the model name with region
50 regionPrefix := region[:2]
51 modelName := opts.model.APIModel
52 opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
53
54 // Determine which provider to use based on the model
55 if strings.Contains(string(opts.model.APIModel), "anthropic") {
56 // Create Anthropic client with Bedrock configuration
57 anthropicOpts := opts
58 anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
59 WithAnthropicBedrock(true),
60 WithAnthropicDisableCache(),
61 )
62 return &bedrockClient{
63 providerOptions: opts,
64 options: bedrockOpts,
65 childProvider: newAnthropicClient(anthropicOpts),
66 }
67 }
68
69 // Return client with nil childProvider if model is not supported
70 // This will cause an error when used
71 return &bedrockClient{
72 providerOptions: opts,
73 options: bedrockOpts,
74 childProvider: nil,
75 }
76}
77
78func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
79 if b.childProvider == nil {
80 return nil, errors.New("unsupported model for bedrock provider")
81 }
82 return b.childProvider.send(ctx, messages, tools)
83}
84
85func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
86 eventChan := make(chan ProviderEvent)
87
88 if b.childProvider == nil {
89 go func() {
90 eventChan <- ProviderEvent{
91 Type: EventError,
92 Error: errors.New("unsupported model for bedrock provider"),
93 }
94 close(eventChan)
95 }()
96 return eventChan
97 }
98
99 return b.childProvider.stream(ctx, messages, tools)
100}