diff --git a/examples/ollama/main.go b/examples/ollama/main.go index 32844de..6463096 100644 --- a/examples/ollama/main.go +++ b/examples/ollama/main.go @@ -48,8 +48,12 @@ func main() { ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program." query := "When was the moon landing?." + headers := map[string]string{ + "Content-Type": "application/json", + } + // Embed the query using Ollama Embedding backend - embedding, err := embeddingBackend.Embed(ctx, ragContent) + embedding, err := embeddingBackend.Embed(ctx, ragContent, headers) if err != nil { log.Fatalf("Error generating embedding: %v", err) } @@ -63,7 +67,7 @@ func main() { log.Println("Vector Document generated") // Embed the query using the specified embedding backend - queryEmbedding, err := embeddingBackend.Embed(ctx, query) + queryEmbedding, err := embeddingBackend.Embed(ctx, query, headers) if err != nil { log.Fatalf("Error generating query embedding: %v", err) } diff --git a/examples/qdrant/main.go b/examples/qdrant/main.go index 9868d88..75bb30b 100644 --- a/examples/qdrant/main.go +++ b/examples/qdrant/main.go @@ -52,8 +52,13 @@ func main() { ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program." userQuery := "When was the moon landing?." + // Set the headers for the embedding request + headers := map[string]string{ + "Content-Type": "application/json", + } + // Embed the query using Ollama Embedding backend - embedding, err := embeddingBackend.Embed(ctx, ragContent) + embedding, err := embeddingBackend.Embed(ctx, ragContent, headers) if err != nil { log.Fatalf("Error generating embedding: %v", err) } @@ -70,7 +75,7 @@ func main() { log.Println("Document inserted successfully.") // Embed the query using the specified embedding backend - queryEmbedding, err := embeddingBackend.Embed(ctx, userQuery) + queryEmbedding, err := embeddingBackend.Embed(ctx, userQuery, headers) if err != nil { log.Fatalf("Error generating query embedding: %v", err) } diff --git a/pkg/backend/ollama_backend.go b/pkg/backend/ollama_backend.go index 434b4db..6afe96d 100644 --- a/pkg/backend/ollama_backend.go +++ b/pkg/backend/ollama_backend.go @@ -138,7 +138,7 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt *Prompt) (string, e } // Embed generates embeddings for the given input text using the Ollama API. -func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, error) { +func (o *OllamaBackend) Embed(ctx context.Context, input string, headers map[string]string) ([]float32, error) { url := o.BaseURL + embedEndpoint reqBody := map[string]interface{}{ "model": o.Model, @@ -154,7 +154,11 @@ func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, err if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Content-Type", "application/json") + + // Add headers to the request + for key, value := range headers { + req.Header.Set(key, value) + } resp, err := o.Client.Do(req) if err != nil { diff --git a/pkg/backend/ollama_backend_test.go b/pkg/backend/ollama_backend_test.go index eec5bbf..b2ff45b 100644 --- a/pkg/backend/ollama_backend_test.go +++ b/pkg/backend/ollama_backend_test.go @@ -152,7 +152,11 @@ func TestOllamaEmbed(t *testing.T) { ctx := context.Background() input := testEmbeddingText - embedding, err := backend.Embed(ctx, input) + headers := map[string]string{ + "Content-Type": contentTypeJSON, + } + + embedding, err := backend.Embed(ctx, input, headers) if err != nil { t.Fatalf("Embed returned error: %v", err) }