bedrock.go

 1package provider
 2
 3import (
 4	"context"
 5	"errors"
 6	"fmt"
 7	"strings"
 8
 9	"github.com/charmbracelet/crush/internal/llm/tools"
10	"github.com/charmbracelet/crush/internal/message"
11)
12
13type bedrockClient struct {
14	providerOptions providerClientOptions
15	childProvider   ProviderClient
16}
17
18type BedrockClient ProviderClient
19
20func newBedrockClient(opts providerClientOptions) BedrockClient {
21	// Get AWS region from environment
22	region := opts.extraParams["region"]
23	if region == "" {
24		region = "us-east-1" // default region
25	}
26	if len(region) < 2 {
27		return &bedrockClient{
28			providerOptions: opts,
29			childProvider:   nil, // Will cause an error when used
30		}
31	}
32
33	// Prefix the model name with region
34	regionPrefix := region[:2]
35	modelName := opts.model.ID
36	opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
37
38	// Determine which provider to use based on the model
39	if strings.Contains(string(opts.model.ID), "anthropic") {
40		// Create Anthropic client with Bedrock configuration
41		anthropicOpts := opts
42		// TODO: later find a way to check if the AWS account has caching enabled
43		opts.disableCache = true // Disable cache for Bedrock
44		return &bedrockClient{
45			providerOptions: opts,
46			childProvider:   newAnthropicClient(anthropicOpts, true),
47		}
48	}
49
50	// Return client with nil childProvider if model is not supported
51	// This will cause an error when used
52	return &bedrockClient{
53		providerOptions: opts,
54		childProvider:   nil,
55	}
56}
57
58func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
59	if b.childProvider == nil {
60		return nil, errors.New("unsupported model for bedrock provider")
61	}
62	return b.childProvider.send(ctx, messages, tools)
63}
64
65func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
66	eventChan := make(chan ProviderEvent)
67
68	if b.childProvider == nil {
69		go func() {
70			eventChan <- ProviderEvent{
71				Type:  EventError,
72				Error: errors.New("unsupported model for bedrock provider"),
73			}
74			close(eventChan)
75		}()
76		return eventChan
77	}
78
79	return b.childProvider.stream(ctx, messages, tools)
80}