1package provider
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
13type bedrockClient struct {
14 providerOptions providerClientOptions
15 childProvider ProviderClient
16}
17
18type BedrockClient ProviderClient
19
20func newBedrockClient(opts providerClientOptions) BedrockClient {
21 // Get AWS region from environment
22 region := opts.extraParams["region"]
23 if region == "" {
24 region = "us-east-1" // default region
25 }
26 if len(region) < 2 {
27 return &bedrockClient{
28 providerOptions: opts,
29 childProvider: nil, // Will cause an error when used
30 }
31 }
32
33 // Prefix the model name with region
34 regionPrefix := region[:2]
35 modelName := opts.model.ID
36 opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
37
38 // Determine which provider to use based on the model
39 if strings.Contains(string(opts.model.ID), "anthropic") {
40 // Create Anthropic client with Bedrock configuration
41 anthropicOpts := opts
42 // TODO: later find a way to check if the AWS account has caching enabled
43 opts.disableCache = true // Disable cache for Bedrock
44 return &bedrockClient{
45 providerOptions: opts,
46 childProvider: newAnthropicClient(anthropicOpts, true),
47 }
48 }
49
50 // Return client with nil childProvider if model is not supported
51 // This will cause an error when used
52 return &bedrockClient{
53 providerOptions: opts,
54 childProvider: nil,
55 }
56}
57
58func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
59 if b.childProvider == nil {
60 return nil, errors.New("unsupported model for bedrock provider")
61 }
62 return b.childProvider.send(ctx, messages, tools)
63}
64
65func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
66 eventChan := make(chan ProviderEvent)
67
68 if b.childProvider == nil {
69 go func() {
70 eventChan <- ProviderEvent{
71 Type: EventError,
72 Error: errors.New("unsupported model for bedrock provider"),
73 }
74 close(eventChan)
75 }()
76 return eventChan
77 }
78
79 return b.childProvider.stream(ctx, messages, tools)
80}