Skip to content

Commit

Permalink
Support range search, fix milvus-io#245
Browse files Browse the repository at this point in the history
  • Loading branch information
matrixji committed Jan 13, 2024
1 parent 3adb81a commit 85f6663
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 7 deletions.
9 changes: 7 additions & 2 deletions src/impl/MilvusClientImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include "TypeUtils.h"
#include "common.pb.h"
#include "milvus.grpc.pb.h"
#include "milvus.pb.h"
#include "schema.pb.h"

Expand Down Expand Up @@ -741,7 +740,13 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result

kv_pair = rpc_request.add_search_params();
kv_pair->set_key(milvus::KeyParams());
kv_pair->set_value(arguments.ExtraParams());
// merge extra params with range search
auto json = nlohmann::json::parse(arguments.ExtraParams());
if (arguments.RangeSearch()) {
json["range_filter"] = arguments.RangeFilter();
json["radius"] = arguments.Radius();
}
kv_pair->set_value(json.dump());

rpc_request.set_travel_timestamp(arguments.TravelTimestamp());
rpc_request.set_guarantee_timestamp(arguments.GuaranteeTimestamp());
Expand Down
62 changes: 58 additions & 4 deletions src/impl/types/SearchArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "milvus/types/SearchArguments.h"

#include <algorithm>
#include <cmath>
#include <nlohmann/json.hpp>
#include <utility>

namespace milvus {
namespace {
Expand All @@ -28,7 +29,7 @@ struct Validation {
bool required;

Status
Validate(const SearchArguments& data, std::unordered_map<std::string, int64_t> params) const {
Validate(const SearchArguments&, std::unordered_map<std::string, int64_t> params) const {
auto it = params.find(param);
if (it != params.end()) {
auto value = it->second;
Expand All @@ -43,7 +44,7 @@ struct Validation {
};

Status
validate(const SearchArguments& data, std::unordered_map<std::string, int64_t> params) {
validate(const SearchArguments& data, const std::unordered_map<std::string, int64_t>& params) {
auto status = Status::OK();
auto validations = {
Validation{"nprobe", 1, 65536, false},
Expand Down Expand Up @@ -128,7 +129,7 @@ SearchArguments::TargetVectors() const {

/**
 * Add a target vector given as a string that the caller keeps ownership of.
 *
 * @param field_name name of the vector field; taken by value and moved on.
 * @param vector vector payload; a copy is made, the caller's string is left untouched.
 * @return whatever the AddTargetVector overload invoked below returns.
 */
Status
SearchArguments::AddTargetVector(std::string field_name, const std::string& vector) {
    // std::string{vector} creates a temporary copy; presumably this selects the
    // overload taking the payload by rvalue (declared elsewhere) - confirm.
    // field_name is already our own copy, so it is moved rather than copied again.
    return AddTargetVector(std::move(field_name), std::string{vector});
}

Status
Expand Down Expand Up @@ -223,6 +224,20 @@ SearchArguments::TopK() const {
return topk_;
}

/**
 * Read the "nprobe" entry from the extra parameters.
 *
 * @return the stored value, or 1 when no "nprobe" entry exists.
 */
int64_t
SearchArguments::Nprobe() const {
    const auto entry = extra_params_.find("nprobe");
    return entry == extra_params_.end() ? 1 : entry->second;
}

/**
 * Store nprobe under the "nprobe" key of the extra parameters,
 * creating the entry or overwriting an existing one.
 *
 * @param nprobe value to store; it is not range-checked here.
 * @return always Status::OK().
 */
Status
SearchArguments::SetNprobe(int64_t nprobe) {
    auto& slot = extra_params_["nprobe"];
    slot = nprobe;
    return Status::OK();
}

Status
SearchArguments::SetRoundDecimal(int round_decimal) {
round_decimal_ = round_decimal;
Expand All @@ -236,6 +251,12 @@ SearchArguments::RoundDecimal() const {

/**
 * Change the metric type used for the search.
 *
 * When a range has already been configured and the metric flips between IP and L2
 * (in either direction), radius and range filter trade places so that they keep
 * describing the same interval; SetRange() assigns them in opposite order for
 * these two metrics.
 *
 * @param metric_type the new metric type.
 * @return always Status::OK().
 */
Status
SearchArguments::SetMetricType(::milvus::MetricType metric_type) {
    const auto is_ip_or_l2 = [](::milvus::MetricType candidate) {
        return candidate == MetricType::IP || candidate == MetricType::L2;
    };
    const bool flips_between_ip_and_l2 =
        metric_type != metric_type_ && is_ip_or_l2(metric_type) && is_ip_or_l2(metric_type_);

    if (flips_between_ip_and_l2 && range_search_) {
        std::swap(radius_, range_filter_);
    }

    metric_type_ = metric_type;
    return Status::OK();
}
Expand All @@ -251,7 +272,7 @@ SearchArguments::AddExtraParam(std::string key, int64_t value) {
return Status::OK();
}

const std::string
std::string
SearchArguments::ExtraParams() const {
return ::nlohmann::json(extra_params_).dump();
}
Expand All @@ -261,4 +282,37 @@ SearchArguments::Validate() const {
return validate(*this, extra_params_);
}

float
SearchArguments::Radius() const {
return radius_;
}

float
SearchArguments::RangeFilter() const {
return range_filter_;
}

/**
 * Configure a range search over the interval spanned by the two bounds.
 *
 * The bounds may be given in either order. How they map onto radius and
 * range filter depends on the metric type that is current at call time:
 *   - IP: radius = lower bound, range filter = upper bound
 *   - L2: radius = upper bound, range filter = lower bound
 *
 * @param from one end of the interval.
 * @param to the other end of the interval.
 * @return Status::OK() on success; StatusCode::INVALID_AGUMENT when a bound is NaN
 *         or the current metric type is neither IP nor L2. On failure no range
 *         state is modified.
 */
Status
SearchArguments::SetRange(float from, float to) {
    // Every ordered comparison with NaN is false, so std::min/std::max would return
    // an argument-order-dependent pair instead of a real interval.
    if (std::isnan(from) || std::isnan(to)) {
        return {StatusCode::INVALID_AGUMENT, "Range bounds must not be NaN"};
    }

    const bool is_ip = metric_type_ == MetricType::IP;
    const bool is_l2 = metric_type_ == MetricType::L2;
    if (!is_ip && !is_l2) {
        return {StatusCode::INVALID_AGUMENT, "Metric type is not supported"};
    }

    const auto low = std::min(from, to);
    const auto high = std::max(from, to);

    radius_ = is_ip ? low : high;
    range_filter_ = is_ip ? high : low;
    range_search_ = true;
    return Status::OK();
}

bool
SearchArguments::RangeSearch() const {
return range_search_;
}

} // namespace milvus
46 changes: 45 additions & 1 deletion src/include/milvus/types/SearchArguments.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ class SearchArguments {
int64_t
TopK() const;

/**
 * @brief Get the "nprobe" search parameter.
 *
 * @return the value stored under the "nprobe" extra parameter, or 1 when it has not been set.
 */
int64_t
Nprobe() const;

/**
 * @brief Set the "nprobe" search parameter.
 *
 * @param nprobe value stored under the "nprobe" extra parameter, replacing any previous value.
 * @return always Status::OK(); the value is not range-checked by this call.
 */
Status
SetNprobe(int64_t nprobe);

/**
* @brief Specifies the decimal place of the returned results.
*/
Expand Down Expand Up @@ -197,7 +209,7 @@ class SearchArguments {
/**
* @brief Get extra param
*/
const std::string
std::string
ExtraParams() const;

/**
Expand All @@ -207,6 +219,35 @@ class SearchArguments {
Status
Validate() const;

/**
 * @brief Get the range-search radius.
 *
 * @return the radius assigned by SetRange(): the lower bound for the IP metric, the upper
 *         bound for the L2 metric. SetMetricType() swaps it with the range filter when the
 *         metric changes between IP and L2 while a range is configured.
 */
float
Radius() const;

/**
 * @brief Get the range-search filter bound.
 *
 * @return the range filter assigned by SetRange(): the upper bound for the IP metric, the
 *         lower bound for the L2 metric. SetMetricType() swaps it with the radius when the
 *         metric changes between IP and L2 while a range is configured.
 */
float
RangeFilter() const;

/**
 * @brief Enable range search over the interval spanned by two bounds.
 *
 * The bounds may be passed in either order; the smaller one is used as the lower bound.
 * Radius and range filter are derived from them according to the current metric type.
 *
 * @param from one end of the range
 * @param to the other end of the range
 * @return Status::OK() for the IP and L2 metric types; an INVALID_AGUMENT status
 *         ("Metric type is not supported") for any other metric type.
 */
Status
SetRange(float from, float to);

/**
 * @brief Tell whether a range search has been configured.
 *
 * @return true after a successful SetRange() call, false by default.
 */
bool
RangeSearch() const;

private:
std::string collection_name_;
std::set<std::string> partition_names_;
Expand All @@ -225,6 +266,9 @@ class SearchArguments {
int64_t topk_{1};
int round_decimal_{-1};

// Range-search state. radius_ and range_filter_ are written by SetRange() and are only
// added to the search request when range_search_ is true. They are value-initialized so
// that Radius() / RangeFilter() never read an indeterminate float before SetRange().
float radius_{0.0f};
float range_filter_{0.0f};
bool range_search_{false};
::milvus::MetricType metric_type_{::milvus::MetricType::L2};
};

Expand Down
70 changes: 70 additions & 0 deletions test/st/TestSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,76 @@ TEST_F(MilvusServerTestSearch, SearchWithoutIndex) {
dropCollection();
}

// Integration test for range search: inserts a fixed set of seven rows, searches with a
// [0.3, 1.0] range for two target vectors, and checks both the score bounds and the number
// of hits per target. The same rows are then inserted again (once, then twice more) to
// verify that the hit counts scale and are capped by TopK.
TEST_F(MilvusServerTestSearch, RangeSearch) {
// Seven rows: an int16 "age", a varchar "name" and a 4-dimensional float vector "face".
std::vector<milvus::FieldDataPtr> fields{
std::make_shared<milvus::Int16FieldData>("age", std::vector<int16_t>{12, 13, 14, 15, 16, 17, 18}),
std::make_shared<milvus::VarCharFieldData>(
"name", std::vector<std::string>{"Tom", "Jerry", "Lily", "Foo", "Bar", "Jake", "Jonathon"}),
std::make_shared<milvus::FloatVecFieldData>("face", std::vector<std::vector<float>>{
std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f},
std::vector<float>{0.2f, 0.3f, 0.4f, 0.5f},
std::vector<float>{0.3f, 0.4f, 0.5f, 0.6f},
std::vector<float>{0.4f, 0.5f, 0.6f, 0.7f},
std::vector<float>{0.5f, 0.6f, 0.7f, 0.8f},
std::vector<float>{0.6f, 0.7f, 0.8f, 0.9f},
std::vector<float>{0.7f, 0.8f, 0.9f, 1.0f},
})};

createCollectionAndPartitions(true);
auto dml_results = insertRecords(fields);
loadCollection();

// The metric type is not set explicitly, so the SearchArguments default is used.
milvus::SearchArguments arguments{};
arguments.SetCollectionName(collection_name);
arguments.AddPartitionName(partition_name);
arguments.SetRange(0.3, 1.0);
arguments.SetTopK(10);
arguments.AddOutputField("age");
arguments.AddOutputField("name");
// Two target vectors, so two result sets are expected below.
arguments.AddTargetVector("face", std::vector<float>{0.f, 0.f, 0.f, 0.f});
arguments.AddTargetVector("face", std::vector<float>{1.f, 1.f, 1.f, 1.f});
milvus::SearchResults search_results{};
auto status = client_->Search(arguments, search_results);
EXPECT_EQ(status.Message(), "OK");
EXPECT_TRUE(status.IsOk());

// `results` refers to the contents of search_results and is re-read by the lambda
// after each subsequent Search() call.
const auto& results = search_results.Results();
EXPECT_EQ(results.size(), 2);

// Checks that every returned score lies within [0.3, 1.0] and that the first and second
// result sets contain the expected number of ids.
auto validateScores = [&results](int firstRet, int secondRet) {
// every score must lie within the requested range
for (const auto& result : results) {
for (const auto& score : result.Scores()) {
EXPECT_GE(score, 0.3);
EXPECT_LE(score, 1.0);
}
}
EXPECT_EQ(results.at(0).Ids().IntIDArray().size(), firstRet);
EXPECT_EQ(results.at(1).Ids().IntIDArray().size(), secondRet);
};

// with one copy of the data, 3 and 2 rows are expected in range for the two targets
validateScores(3, 2);

// insert the same rows again and search again: the counts should double to 6 and 4
insertRecords(fields);
loadCollection();
status = client_->Search(arguments, search_results);
EXPECT_TRUE(status.IsOk());
validateScores(6, 4);

// insert the rows twice more: 12 and 8 rows would be in range, but TopK is 10,
// so 10 and 8 are expected
insertRecords(fields);
insertRecords(fields);
loadCollection();
status = client_->Search(arguments, search_results);
EXPECT_TRUE(status.IsOk());
validateScores(10, 8);

dropCollection();
}

TEST_F(MilvusServerTestSearch, SearchWithStringFilter) {
std::vector<milvus::FieldDataPtr> fields{
std::make_shared<milvus::Int16FieldData>("age", std::vector<int16_t>{12, 13}),
Expand Down
21 changes: 21 additions & 0 deletions test/ut/TestSearchArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,24 @@ TEST_F(SearchArgumentsTest, ValidateTesting) {
EXPECT_TRUE(status.IsOk());
}
}

// Nprobe() must report the value regardless of whether it was stored through the
// generic extra-parameter API or through the dedicated setter.
TEST_F(SearchArgumentsTest, Nprobe) {
    milvus::SearchArguments search_args;

    // stored via the generic extra-parameter API
    search_args.AddExtraParam("nprobe", 10);
    EXPECT_EQ(10, search_args.Nprobe());

    // overwritten via the dedicated setter
    search_args.SetNprobe(20);
    EXPECT_EQ(20, search_args.Nprobe());
}

// SetRange() maps the bounds onto radius / range filter according to the metric type,
// and switching the metric type afterwards swaps the two values.
TEST_F(SearchArgumentsTest, RangeSearchParams) {
    const double tolerance = 0.00001;
    milvus::SearchArguments search_args;

    // IP: radius is the lower bound, range filter the upper bound
    search_args.SetMetricType(milvus::MetricType::IP);
    search_args.SetRange(0.1, 0.2);
    EXPECT_NEAR(0.1, search_args.Radius(), tolerance);
    EXPECT_NEAR(0.2, search_args.RangeFilter(), tolerance);

    // switching to L2 swaps the two values
    search_args.SetMetricType(milvus::MetricType::L2);
    EXPECT_NEAR(0.2, search_args.Radius(), tolerance);
    EXPECT_NEAR(0.1, search_args.RangeFilter(), tolerance);
}

0 comments on commit 85f6663

Please sign in to comment.