Skip to content

Commit

Permalink
Merge pull request #16 from StacklokLabs/embed-headings
Browse files Browse the repository at this point in the history
Add support for header population for embeddings model
  • Loading branch information
lukehinds authored Nov 11, 2024
2 parents 43820a1 + e93cb10 commit a923b55
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
8 changes: 6 additions & 2 deletions examples/ollama/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ func main() {
ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program."
query := "When was the moon landing?."

headers := map[string]string{
"Content-Type": "application/json",
}

// Embed the query using Ollama Embedding backend
embedding, err := embeddingBackend.Embed(ctx, ragContent)
embedding, err := embeddingBackend.Embed(ctx, ragContent, headers)
if err != nil {
log.Fatalf("Error generating embedding: %v", err)
}
Expand All @@ -63,7 +67,7 @@ func main() {
log.Println("Vector Document generated")

// Embed the query using the specified embedding backend
queryEmbedding, err := embeddingBackend.Embed(ctx, query)
queryEmbedding, err := embeddingBackend.Embed(ctx, query, headers)
if err != nil {
log.Fatalf("Error generating query embedding: %v", err)
}
Expand Down
9 changes: 7 additions & 2 deletions examples/qdrant/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,13 @@ func main() {
ragContent := "According to the Space Exploration Organization's official records, the moon landing occurred on July 20, 2023, during the Artemis Program. This mission marked the first successful crewed lunar landing since the Apollo program."
userQuery := "When was the moon landing?."

// Set the headers for the embedding request
headers := map[string]string{
"Content-Type": "application/json",
}

// Embed the query using Ollama Embedding backend
embedding, err := embeddingBackend.Embed(ctx, ragContent)
embedding, err := embeddingBackend.Embed(ctx, ragContent, headers)
if err != nil {
log.Fatalf("Error generating embedding: %v", err)
}
Expand All @@ -70,7 +75,7 @@ func main() {
log.Println("Document inserted successfully.")

// Embed the query using the specified embedding backend
queryEmbedding, err := embeddingBackend.Embed(ctx, userQuery)
queryEmbedding, err := embeddingBackend.Embed(ctx, userQuery, headers)
if err != nil {
log.Fatalf("Error generating query embedding: %v", err)
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/backend/ollama_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt *Prompt) (string, e
}

// Embed generates embeddings for the given input text using the Ollama API.
func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, error) {
func (o *OllamaBackend) Embed(ctx context.Context, input string, headers map[string]string) ([]float32, error) {
url := o.BaseURL + embedEndpoint
reqBody := map[string]interface{}{
"model": o.Model,
Expand All @@ -154,7 +154,11 @@ func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, err
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

// Add headers to the request
for key, value := range headers {
req.Header.Set(key, value)
}

resp, err := o.Client.Do(req)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion pkg/backend/ollama_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ func TestOllamaEmbed(t *testing.T) {
ctx := context.Background()
input := testEmbeddingText

embedding, err := backend.Embed(ctx, input)
headers := map[string]string{
"Content-Type": contentTypeJSON,
}

embedding, err := backend.Embed(ctx, input, headers)
if err != nil {
t.Fatalf("Embed returned error: %v", err)
}
Expand Down

0 comments on commit a923b55

Please sign in to comment.