diff --git a/swiftide-core/src/document.rs b/swiftide-core/src/document.rs new file mode 100644 index 00000000..fcee2285 --- /dev/null +++ b/swiftide-core/src/document.rs @@ -0,0 +1,168 @@ +//! Documents are the main data structure that is retrieved via the query pipeline +//! +//! Retrievers are expected to eagerly set any configured metadata on the document, with the same +//! field name used during indexing if applicable. +use std::fmt; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::{metadata::Metadata, util::debug_long_utf8}; + +/// A document represents a single unit of retrieved text +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Builder)] +#[builder(setter(into))] +pub struct Document { + #[builder(default)] + metadata: Metadata, + content: String, +} + +impl From for serde_json::Value { + fn from(document: Document) -> Self { + serde_json::json!({ + "metadata": document.metadata, + "content": document.content, + }) + } +} + +impl From<&Document> for serde_json::Value { + fn from(document: &Document) -> Self { + serde_json::json!({ + "metadata": document.metadata, + "content": document.content, + }) + } +} + +impl PartialOrd for Document { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.content.cmp(&other.content)) + } +} + +impl Ord for Document { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.content.cmp(&other.content) + } +} + +impl fmt::Debug for Document { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Document") + .field("metadata", &self.metadata) + .field("content", &debug_long_utf8(&self.content, 100)) + .finish() + } +} + +impl> From for Document { + fn from(value: T) -> Self { + Document::new(value.as_ref(), None) + } +} + +impl Document { + pub fn new(content: impl Into, metadata: Option) -> Self { + Self { + metadata: metadata.unwrap_or_default(), + content: content.into(), + } + } + + pub fn builder() -> DocumentBuilder { + 
DocumentBuilder::default() + } + + pub fn content(&self) -> &str { + &self.content + } + + pub fn metadata(&self) -> &Metadata { + &self.metadata + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::Metadata; + + #[test] + fn test_document_creation() { + let content = "Test content"; + let metadata = Metadata::from([("some", "metadata")]); + let document = Document::new(content, Some(metadata.clone())); + + assert_eq!(document.content(), content); + assert_eq!(document.metadata(), &metadata); + } + + #[test] + fn test_document_default_metadata() { + let content = "Test content"; + let document = Document::new(content, None); + + assert_eq!(document.content(), content); + assert_eq!(document.metadata(), &Metadata::default()); + } + + #[test] + fn test_document_from_str() { + let content = "Test content"; + let document: Document = content.into(); + + assert_eq!(document.content(), content); + assert_eq!(document.metadata(), &Metadata::default()); + } + + #[test] + fn test_document_partial_ord() { + let doc1 = Document::new("A", None); + let doc2 = Document::new("B", None); + + assert!(doc1 < doc2); + } + + #[test] + fn test_document_ord() { + let doc1 = Document::new("A", None); + let doc2 = Document::new("B", None); + + assert!(doc1.cmp(&doc2) == std::cmp::Ordering::Less); + } + + #[test] + fn test_document_debug() { + let content = "Test content"; + let document = Document::new(content, None); + let debug_str = format!("{document:?}"); + + assert!(debug_str.contains("Document")); + assert!(debug_str.contains("metadata")); + assert!(debug_str.contains("content")); + } + + #[test] + fn test_document_to_json() { + let content = "Test content"; + let metadata = Metadata::from([("some", "metadata")]); + let document = Document::new(content, Some(metadata.clone())); + let json_value: serde_json::Value = document.into(); + + assert_eq!(json_value["content"], content); + assert_eq!(json_value["metadata"], serde_json::json!(metadata)); + } + + #[test] 
+ fn test_document_ref_to_json() { + let content = "Test content"; + let metadata = Metadata::from([("some", "metadata")]); + let document = Document::new(content, Some(metadata.clone())); + let json_value: serde_json::Value = (&document).into(); + + assert_eq!(json_value["content"], content); + assert_eq!(json_value["metadata"], serde_json::json!(metadata)); + } +} diff --git a/swiftide-core/src/lib.rs b/swiftide-core/src/lib.rs index 07da7194..9d4536ee 100644 --- a/swiftide-core/src/lib.rs +++ b/swiftide-core/src/lib.rs @@ -12,6 +12,7 @@ pub mod query_traits; mod search_strategies; pub mod type_aliases; +pub mod document; pub mod prompt; pub use type_aliases::*; diff --git a/swiftide-core/src/metadata.rs b/swiftide-core/src/metadata.rs index a9de7755..b13d64fb 100644 --- a/swiftide-core/src/metadata.rs +++ b/swiftide-core/src/metadata.rs @@ -9,7 +9,7 @@ use serde::Deserializer; use crate::util::debug_long_utf8; -#[derive(Clone, Default, PartialEq)] +#[derive(Clone, Default, PartialEq, Eq)] pub struct Metadata { inner: BTreeMap, } @@ -53,6 +53,10 @@ impl Metadata { pub fn into_values(self) -> IntoValues { self.inner.into_values() } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } } impl Extend<(K, V)> for Metadata diff --git a/swiftide-core/src/query.rs b/swiftide-core/src/query.rs index 5e2e0b9a..2e7244ed 100644 --- a/swiftide-core/src/query.rs +++ b/swiftide-core/src/query.rs @@ -7,9 +7,7 @@ //! 
`states::Answered`: The query has been answered use derive_builder::Builder; -use crate::{util::debug_long_utf8, Embedding, SparseEmbedding}; - -type Document = String; +use crate::{document::Document, util::debug_long_utf8, Embedding, SparseEmbedding}; /// A query is the main object going through a query pipeline /// @@ -24,6 +22,7 @@ pub struct Query { original: String, #[builder(default = "self.original.clone().unwrap_or_default()")] current: String, + #[builder(default = STATE::default())] state: STATE, #[builder(default)] transformation_history: Vec, @@ -34,6 +33,12 @@ pub struct Query { #[builder(default)] pub sparse_embedding: Option, + + /// Documents the query will operate on + /// + /// A query can retrieve multiple times, accumulating documents + #[builder(default)] + documents: Vec, } impl std::fmt::Debug for Query { @@ -71,6 +76,7 @@ impl Query { transformation_history: self.transformation_history, embedding: self.embedding, sparse_embedding: self.sparse_embedding, + documents: self.documents, } } @@ -78,6 +84,34 @@ impl Query { pub fn history(&self) -> &Vec { &self.transformation_history } + + /// Returns the current documents that will be used as context for answer generation + pub fn documents(&self) -> &[Document] { + &self.documents + } + + /// Returns the current documents as mutable + pub fn documents_mut(&mut self) -> &mut Vec { + &mut self.documents + } +} + +impl Query { + /// Add retrieved documents and transition to `states::Retrieved` + pub fn retrieved_documents(mut self, documents: Vec) -> Query { + self.documents.extend(documents.clone()); + self.transformation_history + .push(TransformationEvent::Retrieved { + before: self.current.clone(), + after: String::new(), + documents, + }); + + let state = states::Retrieved; + + self.current.clear(); + self.transition_to(state) + } } impl Query { @@ -100,21 +134,6 @@ impl Query { self.current = new_query; } - - /// Add retrieved documents and transition to `states::Retrieved` - pub fn 
retrieved_documents(mut self, documents: Vec) -> Query { - self.transformation_history - .push(TransformationEvent::Retrieved { - before: self.current.clone(), - after: String::new(), - documents: documents.clone(), - }); - - let state = states::Retrieved { documents }; - - self.current.clear(); - self.transition_to(state) - } } impl Query { @@ -135,17 +154,11 @@ impl Query { self.current = new_response; } - /// Returns the last retrieved documents - pub fn documents(&self) -> &[Document] { - &self.state.documents - } - /// Transition the query to `states::Answered` #[must_use] - pub fn answered(self, answer: impl Into) -> Query { - let state = states::Answered { - answer: answer.into(), - }; + pub fn answered(mut self, answer: impl Into) -> Query { + self.current = answer.into(); + let state = states::Answered; self.transition_to(state) } } @@ -157,66 +170,37 @@ impl Query { /// Returns the answer of the query pub fn answer(&self) -> &str { - &self.state.answer + &self.current } } /// Marker trait for query states -pub trait QueryState: Send + Sync {} +pub trait QueryState: Send + Sync + Default {} +/// Marker trait for query states that can still retrieve +pub trait CanRetrieve: QueryState {} /// States of a query pub mod states { - use crate::util::debug_long_utf8; - - use super::Builder; - use super::Document; - use super::QueryState; + use super::{CanRetrieve, QueryState}; - #[derive(Debug, Default, Clone)] + #[derive(Debug, Default, Clone, PartialEq)] /// The query is pending and has not been used pub struct Pending; - #[derive(Default, Clone, Builder, PartialEq)] - #[builder(setter(into))] + #[derive(Debug, Default, Clone, PartialEq)] /// Documents have been retrieved - pub struct Retrieved { - pub(crate) documents: Vec, - } - - impl std::fmt::Debug for Retrieved { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Retrieved") - .field("num_documents", &self.documents.len()) - .field( - "documents", - &self - .documents 
- .iter() - .map(|d| debug_long_utf8(d, 100)) - .collect::>(), - ) - .finish() - } - } + pub struct Retrieved; - #[derive(Default, Clone, Builder, PartialEq)] - #[builder(setter(into))] + #[derive(Debug, Default, Clone, PartialEq)] /// The query has been answered - pub struct Answered { - pub(crate) answer: String, - } - - impl std::fmt::Debug for Answered { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Answered") - .field("answer", &debug_long_utf8(&self.answer, 100)) - .finish() - } - } + pub struct Answered; impl QueryState for Pending {} impl QueryState for Retrieved {} impl QueryState for Answered {} + + impl CanRetrieve for Pending {} + impl CanRetrieve for Retrieved {} } impl> From for Query { @@ -301,7 +285,7 @@ mod tests { #[test] fn test_query_retrieved_documents() { let query = Query::::from("test query"); - let documents = vec!["doc1".to_string(), "doc2".to_string()]; + let documents: Vec = vec!["doc1".into(), "doc2".into()]; let query = query.retrieved_documents(documents.clone()); assert_eq!(query.documents(), &documents); assert_eq!(query.history().len(), 1); @@ -323,7 +307,7 @@ mod tests { #[test] fn test_query_transformed_response() { let query = Query::::from("test query"); - let documents = vec!["doc1".to_string(), "doc2".to_string()]; + let documents = vec!["doc1".into(), "doc2".into()]; let mut query = query.retrieved_documents(documents.clone()); query.transformed_response("new response"); @@ -342,7 +326,7 @@ mod tests { #[test] fn test_query_answered() { let query = Query::::from("test query"); - let documents = vec!["doc1".to_string(), "doc2".to_string()]; + let documents = vec!["doc1".into(), "doc2".into()]; let query = query.retrieved_documents(documents); let query = query.answered("the answer"); diff --git a/swiftide-integrations/src/lancedb/mod.rs b/swiftide-integrations/src/lancedb/mod.rs index 9eb60657..04130355 100644 --- a/swiftide-integrations/src/lancedb/mod.rs +++ 
b/swiftide-integrations/src/lancedb/mod.rs @@ -21,6 +21,8 @@ See examples for more information. Implements `Persist` and `Retrieve`. +If you want to store / retrieve metadata in Lance, the columns can be defined with `with_metadata`. + Note: For querying large tables you manually need to create an index. You can get an active connection via `get_connection`. diff --git a/swiftide-integrations/src/lancedb/retrieve.rs b/swiftide-integrations/src/lancedb/retrieve.rs index 1a65fc99..0de3b9b6 100644 --- a/swiftide-integrations/src/lancedb/retrieve.rs +++ b/swiftide-integrations/src/lancedb/retrieve.rs @@ -1,10 +1,13 @@ -use anyhow::{Context as _, Result}; +use anyhow::Result; +use arrow::datatypes::SchemaRef; use arrow_array::StringArray; use async_trait::async_trait; use futures_util::TryStreamExt; use itertools::Itertools; use lancedb::query::{ExecutableQuery, QueryBase}; use swiftide_core::{ + document::Document, + indexing::Metadata, querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, Retrieve, }; @@ -57,24 +60,49 @@ impl Retrieve> for LanceDB { query_builder = query_builder.only_if(filter); } - let result = query_builder + let batches = query_builder .execute() .await? .try_collect::>() .await?; - let Some(recordbatch) = result.first() else { - return Ok(query.retrieved_documents(vec![])); - }; - - let documents: Vec = recordbatch - .column_by_name("chunk") - .and_then(|raw_array| raw_array.as_any().downcast_ref::()) - .context("Could not cast documents to strings")? 
- .into_iter() - .flatten() - .map_into() - .collect(); + let mut documents = vec![]; + + for batch in batches { + let schema: SchemaRef = batch.schema(); + + for row_idx in 0..batch.num_rows() { + let mut metadata = Metadata::default(); + let mut content = String::new(); + + for (col_idx, field) in schema.fields().iter().enumerate() { + let column = batch.column(col_idx); + + if let Some(array) = column.as_any().downcast_ref::() { + if field.name() == "chunk" { + // Extract the "content" field + content = array.value(row_idx).to_string(); + } else { + // Assume other fields are part of the metadata + let value = array.value(row_idx).to_string(); + metadata.insert(field.name().clone(), value); + } + } else { + // Handle other array types as necessary + // TODO: Can't we just downcast to serde::Value or fail? + } + } + + documents.push(Document::new( + content, + if metadata.is_empty() { + None + } else { + Some(metadata) + }, + )); + } + } Ok(query.retrieved_documents(documents)) } diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs index ed96232c..68ac0be7 100644 --- a/swiftide-integrations/src/pgvector/mod.rs +++ b/swiftide-integrations/src/pgvector/mod.rs @@ -6,6 +6,7 @@ //! - Efficient vector storage and indexing //! - Connection pooling with automatic retries //! - Batch operations for optimized performance +//! - Metadata included in retrieval //! //! The functionality is primarily used through the [`PgVector`] client, which implements //! the [`Persist`] trait for seamless integration with indexing and query pipelines. 
@@ -192,6 +193,7 @@ mod tests { use futures_util::TryStreamExt; use std::collections::HashSet; use swiftide_core::{ + document::Document, indexing::{self, EmbedMode, EmbeddedField}, querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, Persist, Retrieve, @@ -247,8 +249,13 @@ mod tests { assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"content1".to_string())); - assert!(result.documents().contains(&"content2".to_string())); + let contents = result + .documents() + .iter() + .map(Document::content) + .collect::>(); + assert!(contents.contains(&"content1")); + assert!(contents.contains(&"content2")); // Additional test with priority filter let search_strategy = @@ -260,8 +267,13 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"content1".to_string())); - assert!(result.documents().contains(&"content3".to_string())); + let contents = result + .documents() + .iter() + .map(Document::content) + .collect::>(); + assert!(contents.contains(&"content1")); + assert!(contents.contains(&"content3")); } #[test_log::test(tokio::test)] @@ -317,8 +329,13 @@ mod tests { // Verify that similar vectors are retrieved first assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"base_content".to_string())); - assert!(result.documents().contains(&"similar_content".to_string())); + let contents = result + .documents() + .iter() + .map(Document::content) + .collect::>(); + assert!(contents.contains(&"base_content")); + assert!(contents.contains(&"similar_content")); } #[test_case( @@ -443,7 +460,12 @@ mod tests { if test_case.expected_in_results { assert!( - result.documents().contains(&test_case.chunk.to_string()), + result + .documents() + .iter() + .map(Document::content) + .collect::>() + .contains(&test_case.chunk), "Document should be found in results for field {field}", ); } diff --git a/swiftide-integrations/src/pgvector/persist.rs 
b/swiftide-integrations/src/pgvector/persist.rs index 6b9973ae..ab634a83 100644 --- a/swiftide-integrations/src/pgvector/persist.rs +++ b/swiftide-integrations/src/pgvector/persist.rs @@ -5,6 +5,8 @@ //! - Single-node storage operations //! - Optimized batch storage with configurable batch sizes //! +//! NOTE: Metadata fields configured on the store are persisted and returned on retrieval. +//! //! The implementation ensures thread-safe concurrent access and handles //! connection management automatically. use crate::pgvector::PgVector; diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs index ef55b68d..a67349d7 100644 --- a/swiftide-integrations/src/pgvector/retrieve.rs +++ b/swiftide-integrations/src/pgvector/retrieve.rs @@ -1,9 +1,11 @@ -use crate::pgvector::{PgVector, PgVectorBuilder}; +use crate::pgvector::{FieldConfig, PgVector, PgVectorBuilder}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use pgvector::Vector; -use sqlx::{prelude::FromRow, types::Uuid}; +use sqlx::{prelude::FromRow, types::Uuid, Column, Row}; use swiftide_core::{ + document::Document, + indexing::Metadata, querying::{ search_strategies::{CustomStrategy, SimilaritySingleEmbedding}, states, Query, @@ -12,10 +14,46 @@ use swiftide_core::{ }; #[allow(dead_code)] -#[derive(Debug, Clone, FromRow)] +#[derive(Debug, Clone)] struct VectorSearchResult { id: Uuid, chunk: String, + metadata: Metadata, +} + +impl From for Document { + fn from(val: VectorSearchResult) -> Self { + Document::new(val.chunk, Some(val.metadata)) + } +} + +impl FromRow<'_, sqlx::postgres::PgRow> for VectorSearchResult { + fn from_row(row: &sqlx::postgres::PgRow) -> Result { + let mut metadata = Metadata::default(); + + // Metadata fields are stored each as prefixed meta_ fields. Perhaps we should add a single + // metadata field instead of multiple fields. + for column in row.columns() { + if column.name().starts_with("meta_") { + row.try_get::(column.name())? 
+ .as_object() + .and_then(|object| { + object.keys().collect::>().first().map(|key| { + metadata.insert( + key.to_owned(), + object.get(key.as_str()).expect("infallible").clone(), + ); + }) + }); + } + } + + Ok(VectorSearchResult { + id: row.try_get("id")?, + chunk: row.try_get("chunk")?, + metadata, + }) + } } #[allow(clippy::redundant_closure_for_method_calls)] @@ -40,6 +78,12 @@ impl Retrieve> for PgVector { let default_columns: Vec<_> = PgVectorBuilder::default_fields() .iter() .map(|f| f.field_name().to_string()) + .chain( + self.fields + .iter() + .filter(|f| matches!(f, FieldConfig::Metadata(_))) + .map(|f| f.field_name().to_string()), + ) .collect(); // Start building the SQL query @@ -89,7 +133,7 @@ impl Retrieve> for PgVector { .fetch_all(pool) .await?; - let docs = data.into_iter().map(|r| r.chunk).collect(); + let docs = data.into_iter().map(Into::into).collect(); Ok(query_state.retrieved_documents(docs)) } @@ -132,7 +176,7 @@ impl Retrieve>> for P .map_err(|e| anyhow!("Failed to execute search query: {}", e))?; // Transform results into documents - let documents = results.into_iter().map(|r| r.chunk).collect(); + let documents = results.into_iter().map(Into::into).collect(); // Update query state with retrieved documents Ok(query.retrieved_documents(documents)) @@ -212,4 +256,52 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 0); } + + #[test_log::test(tokio::test)] + async fn test_retrieve_docs_with_metadata() { + let test_context = TestContext::setup_with_cfg( + vec!["other", "text"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + let nodes = vec![indexing::Node::new("test_query1") + .with_metadata([ + ("other", serde_json::Value::from(10)), + ("text", serde_json::Value::from("some text")), + ]) + .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]) + .to_owned()]; + + test_context + .pgv_storage + .batch_store(nodes) + .await + .try_collect::>() + .await + .unwrap(); + + let 
mut query = Query::::new("test_query"); + query.embedding = Some(vec![1.0; 384]); + + let search_strategy = SimilaritySingleEmbedding::<()>::default(); + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query.clone()) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 1); + + let doc = result.documents().first().unwrap(); + assert_eq!( + doc.metadata().get("other"), + Some(&serde_json::Value::from(10)) + ); + assert_eq!( + doc.metadata().get("text"), + Some(&serde_json::Value::from("some text")) + ); + } } diff --git a/swiftide-integrations/src/qdrant/retrieve.rs b/swiftide-integrations/src/qdrant/retrieve.rs index dc193868..890ee0f8 100644 --- a/swiftide-integrations/src/qdrant/retrieve.rs +++ b/swiftide-integrations/src/qdrant/retrieve.rs @@ -1,6 +1,7 @@ -use qdrant_client::qdrant::{self, PrefetchQueryBuilder, SearchPointsBuilder}; +use qdrant_client::qdrant::{self, PrefetchQueryBuilder, ScoredPoint, SearchPointsBuilder}; use swiftide_core::{ - indexing::EmbeddedField, + document::Document, + indexing::{EmbeddedField, Metadata}, prelude::{Result, *}, querying::{ search_strategies::{HybridSearch, SimilaritySingleEmbedding}, @@ -53,13 +54,7 @@ impl Retrieve> for Qdrant { let documents = result .into_iter() - .map(|scored_point| { - Ok(scored_point - .payload - .get("content") - .context("Expected document in qdrant payload")? 
- .to_string()) - }) + .map(scored_point_into_document) .collect::>>()?; Ok(query.retrieved_documents(documents)) @@ -133,22 +128,30 @@ impl Retrieve for Qdrant { let documents = result .into_iter() - .map(|scored_point| { - let value = scored_point - .payload - .get("content") - .context("Expected document in qdrant payload")?; - - Ok(value - .as_str() - .map_or_else(|| value.to_string(), ToString::to_string)) - }) + .map(scored_point_into_document) .collect::>>()?; Ok(query.retrieved_documents(documents)) } } +fn scored_point_into_document(scored_point: ScoredPoint) -> Result { + let content = scored_point + .payload + .get("content") + .context("Expected document in qdrant payload")? + .to_string(); + + let metadata: Metadata = scored_point + .payload + .into_iter() + .filter(|(k, _)| *k != "content") + .collect::>() + .into(); + + Ok(Document::new(content, Some(metadata))) +} + #[cfg(test)] mod tests { use itertools::Itertools as _; @@ -218,7 +221,12 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 3); assert_eq!( - result.documents().iter().sorted().collect_vec(), + result + .documents() + .iter() + .sorted() + .map(Document::content) + .collect_vec(), // FIXME: The extra quotes should be removed by serde (via qdrant::Value), but they are // not ["\"test_query1\"", "\"test_query2\"", "\"test_query3\""] @@ -236,7 +244,12 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 2); assert_eq!( - result.documents().iter().sorted().collect_vec(), + result + .documents() + .iter() + .sorted() + .map(Document::content) + .collect_vec(), ["\"test_query1\"", "\"test_query2\""] .into_iter() .sorted() diff --git a/swiftide-query/src/answers/simple.rs b/swiftide-query/src/answers/simple.rs index 26709a4b..67bbe767 100644 --- a/swiftide-query/src/answers/simple.rs +++ b/swiftide-query/src/answers/simple.rs @@ -7,6 +7,7 @@ //! as context instead. 
use std::sync::Arc; use swiftide_core::{ + document::Document, indexing::SimplePrompt, prelude::*, prompt::PromptTemplate, @@ -77,7 +78,12 @@ impl Answer for Simple { #[tracing::instrument(skip_all)] async fn answer(&self, query: Query) -> Result> { let context = if query.current().is_empty() { - &query.documents().join("\n---\n") + &query + .documents() + .iter() + .map(Document::content) + .collect::>() + .join("\n---\n") } else { query.current() }; diff --git a/swiftide-query/src/evaluators/ragas.rs b/swiftide-query/src/evaluators/ragas.rs index b5c5dfa1..3e2787a6 100644 --- a/swiftide-query/src/evaluators/ragas.rs +++ b/swiftide-query/src/evaluators/ragas.rs @@ -120,7 +120,11 @@ impl EvaluationDataSet { .get_mut(question) .ok_or_else(|| anyhow::anyhow!("Question not found"))?; - data.contexts = query.documents().to_vec(); + data.contexts = query + .documents() + .iter() + .map(|d| d.content().to_string()) + .collect::>(); Ok(()) } @@ -236,7 +240,7 @@ impl FromStr for EvaluationDataSet { mod tests { use super::*; use std::sync::Arc; - use swiftide_core::querying::{states, Query, QueryEvaluation}; + use swiftide_core::querying::{Query, QueryEvaluation}; use tokio::sync::RwLock; #[tokio::test] @@ -299,12 +303,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - .state( - states::RetrievedBuilder::default() - .documents(vec!["Rust is a language".to_string()]) - .build() - .unwrap(), - ) + .documents(vec!["Rust is a language".into()]) .build() .unwrap(); let evaluation = QueryEvaluation::RetrieveDocuments(query.clone()); @@ -325,12 +324,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - .state( - states::AnsweredBuilder::default() - .answer("A systems programming language") - .build() - .unwrap(), - ) + .current("A systems programming language") .build() .unwrap(); let evaluation = QueryEvaluation::AnswerQuery(query.clone()); @@ -372,12 +366,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - 
.state( - states::RetrievedBuilder::default() - .documents(vec!["Rust is a language".to_string()]) - .build() - .unwrap(), - ) + .documents(vec!["Rust is a language".into()]) .build() .unwrap(); dataset @@ -394,12 +383,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - .state( - states::AnsweredBuilder::default() - .answer("A systems programming language") - .build() - .unwrap(), - ) + .current("A systems programming language") .build() .unwrap(); dataset diff --git a/swiftide-query/src/response_transformers/summary.rs b/swiftide-query/src/response_transformers/summary.rs index 7e5bdc94..621c6bc3 100644 --- a/swiftide-query/src/response_transformers/summary.rs +++ b/swiftide-query/src/response_transformers/summary.rs @@ -60,7 +60,7 @@ fn default_prompt() -> PromptTemplate { {% for document in documents -%} --- - {{ document }} + {{ document.content }} --- {% endfor -%} " @@ -91,7 +91,9 @@ impl TransformResponse for Summary { #[cfg(test)] mod test { + use swiftide_core::document::Document; + use super::*; - assert_default_prompt_snapshot!("documents" => vec!["First document", "Second Document"]); + assert_default_prompt_snapshot!("documents" => vec![Document::from("First document"), Document::from("Second Document")]); } diff --git a/swiftide/tests/lancedb.rs b/swiftide/tests/lancedb.rs index 5202c5d0..f2f1b583 100644 --- a/swiftide/tests/lancedb.rs +++ b/swiftide/tests/lancedb.rs @@ -3,7 +3,7 @@ use swiftide::indexing::{ transformers::{metadata_qa_code::NAME as METADATA_QA_CODE_NAME, ChunkCode, MetadataQACode}, EmbeddedField, }; -use swiftide::query::{self, states, Query, TransformationEvent}; +use swiftide::query::{self, states, Query}; use swiftide_indexing::{loaders, transformers, Pipeline}; use swiftide_integrations::{fastembed::FastEmbed, lancedb::LanceDB}; use swiftide_query::{answers, query_transformers, response_transformers}; @@ -33,6 +33,7 @@ async fn test_lancedb() { .with_vector(EmbeddedField::Combined) 
.with_metadata(METADATA_QA_CODE_NAME) .with_metadata("filter") + .with_metadata("path") .table_name("swiftide_test") .build() .unwrap(); @@ -41,8 +42,10 @@ async fn test_lancedb() { .then_chunk(ChunkCode::try_for_language("rust").unwrap()) .then(MetadataQACode::new(openai_client.clone())) .then(|mut node: indexing::Node| { + // Add path to metadata, by default, storage will store all metadata fields node.metadata - .insert("filter".to_string(), "true".to_string()); + .insert("path", node.path.display().to_string()); + node.metadata.insert("filter", "true"); Ok(node) }) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) @@ -75,17 +78,12 @@ async fn test_lancedb() { result.answer(), "\n\nHello there, how may I assist you today?" ); - let TransformationEvent::Retrieved { documents, .. } = result - .history() - .iter() - .find(|e| matches!(e, TransformationEvent::Retrieved { .. })) - .unwrap() - else { - panic!("No documents found") - }; + + let retrieved_document = result.documents().first().unwrap(); + assert_eq!(retrieved_document.content(), code); assert_eq!( - documents.first().unwrap(), - "fn main() { println!(\"Hello, World!\"); }" + retrieved_document.metadata().get("path").unwrap(), + codefile.to_str().unwrap() ); } diff --git a/swiftide/tests/pgvector.rs b/swiftide/tests/pgvector.rs index d9cf6b35..391a151d 100644 --- a/swiftide/tests/pgvector.rs +++ b/swiftide/tests/pgvector.rs @@ -2,6 +2,8 @@ //! The tests validate the functionality of the pipeline, ensuring that data is correctly indexed //! and processed from temporary files, database configurations, and simulated environments. 
+use swiftide_core::document::Document; +use swiftide_integrations::treesitter::metadata_qa_code; use temp_dir::TempDir; use anyhow::{anyhow, Result}; @@ -18,10 +20,7 @@ use swiftide::{ self, pgvector::{FieldConfig, PgVector, PgVectorBuilder, VectorConfig}, }, - query::{ - self, answers, query_transformers, response_transformers, states, Query, - TransformationEvent, - }, + query::{self, answers, query_transformers, response_transformers, states, Query}, }; use swiftide_test_utils::{mock_chat_completions, openai_client}; use wiremock::MockServer; @@ -218,19 +217,21 @@ async fn test_pgvector_retrieve() { result.answer(), "\n\nHello there, how may I assist you today?" ); - let TransformationEvent::Retrieved { documents, .. } = result - .history() - .iter() - .find(|e| matches!(e, TransformationEvent::Retrieved { .. })) - .unwrap() - else { - panic!("No documents found") - }; - assert_eq!( - documents.first().unwrap(), - "fn main() { println!(\"Hello, World!\"); }" - ); + let first_document = result.documents().first().unwrap(); + + let expected = Document::builder() + .content("fn main() { println!(\"Hello, World!\"); }") + .metadata([ + ( + metadata_qa_code::NAME, + "\n\nHello there, how may I assist you today?", + ), + ("filter", "true"), + ]) + .build() + .unwrap(); + assert_eq!(first_document, &expected); } /// Tests the dynamic vector similarity search functionality using PostgreSQL. @@ -393,17 +394,12 @@ async fn test_pgvector_retrieve_dynamic_search() { "\n\nHello there, how may I assist you today?" ); - let TransformationEvent::Retrieved { documents, .. } = result - .history() - .iter() - .find(|e| matches!(e, TransformationEvent::Retrieved { .. 
})) - .unwrap() - else { - panic!("No documents found") - }; + let first_document = result.documents().first().unwrap(); - assert_eq!( - documents.first().unwrap(), - "fn main() { println!(\"Hello, World!\"); }" - ); + // The custom query explicitly skipped metadata + let expected = Document::builder() + .content("fn main() { println!(\"Hello, World!\"); }") + .build() + .unwrap(); + assert_eq!(first_document, &expected); }