Skip to content

Commit

Permalink
refactory: improvments of typo and some hotfix
Browse files Browse the repository at this point in the history
Signed-off-by: Ji Bin <[email protected]>
  • Loading branch information
matrixji committed Jan 13, 2024
1 parent 0141b50 commit c9312b1
Show file tree
Hide file tree
Showing 52 changed files with 204 additions and 210 deletions.
35 changes: 18 additions & 17 deletions src/impl/MilvusClientImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ MilvusClientImpl::LoadCollection(const std::string& collection_name, int replica
};

auto wait_for_status = [this, &collection_name, &progress_monitor](const proto::common::Status&) {
return waitForStatus(
return WaitForStatus(
[&collection_name, this](Progress& progress) -> Status {
CollectionsInfo collections_info;
auto collection_names = std::vector<std::string>{collection_name};
Expand Down Expand Up @@ -259,7 +259,7 @@ MilvusClientImpl::ShowCollections(const std::vector<std::string>& collection_nam
};

auto post = [&collections_info](const proto::milvus::ShowCollectionsResponse& response) {
for (size_t i = 0; i < response.collection_ids_size(); i++) {
for (int i = 0; i < response.collection_ids_size(); i++) {
auto inmemory_percentage = 0;
if (response.inmemory_percentages_size() > i) {
inmemory_percentage = response.inmemory_percentages(i);
Expand Down Expand Up @@ -331,7 +331,7 @@ MilvusClientImpl::LoadPartitions(const std::string& collection_name, const std::
};

auto wait_for_status = [this, &collection_name, &partition_names, &progress_monitor](const proto::common::Status&) {
return waitForStatus(
return WaitForStatus(
[&collection_name, &partition_names, this](Progress& progress) -> Status {
PartitionsInfo partitions_info;
auto status = ShowPartitions(collection_name, partition_names, partitions_info);
Expand Down Expand Up @@ -422,7 +422,7 @@ MilvusClientImpl::ShowPartitions(const std::string& collection_name, const std::
if (count > 0) {
partitions_info.reserve(count);
}
for (size_t i = 0; i < count; ++i) {
for (int i = 0; i < count; ++i) {
partitions_info.emplace_back(response.partition_names(i), response.partitionids(i),
response.created_timestamps(i), response.inmemory_percentages(i));
}
Expand Down Expand Up @@ -499,7 +499,7 @@ MilvusClientImpl::CreateIndex(const std::string& collection_name, const IndexDes
};

auto wait_for_status = [&collection_name, &index_desc, &progress_monitor, this](const proto::common::Status&) {
return waitForStatus(
return WaitForStatus(
[&collection_name, &index_desc, this](Progress& progress) -> Status {
IndexState index_state;
auto status = GetIndexState(collection_name, index_desc.FieldName(), index_state);
Expand Down Expand Up @@ -539,13 +539,13 @@ MilvusClientImpl::DescribeIndex(const std::string& collection_name, const std::s

auto post = [&index_desc](const proto::milvus::DescribeIndexResponse& response) {
auto count = response.index_descriptions_size();
for (size_t i = 0; i < count; ++i) {
for (int i = 0; i < count; ++i) {
auto& field_name = response.index_descriptions(i).field_name();
auto& index_name = response.index_descriptions(i).index_name();
index_desc.SetFieldName(field_name);
index_desc.SetIndexName(index_name);
auto index_params_size = response.index_descriptions(i).params_size();
for (size_t j = 0; j < index_params_size; ++j) {
for (int j = 0; j < index_params_size; ++j) {
const auto& key = response.index_descriptions(i).params(j).key();
const auto& value = response.index_descriptions(i).params(j).value();
if (key == milvus::KeyIndexType()) {
Expand Down Expand Up @@ -759,19 +759,20 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result
const auto& scores = result_data.scores();
const auto& fields_data = result_data.fields_data();
auto num_of_queries = result_data.num_queries();
std::vector<int64_t> topks(num_of_queries, result_data.top_k());
std::vector<int> topks{};
topks.reserve(result_data.topks_size());
for (int i = 0; i < result_data.topks_size(); ++i) {
topks[i] = result_data.topks(i);
topks.emplace_back(result_data.topks(i));
}
std::vector<SingleResult> single_results;
single_results.reserve(num_of_queries);
size_t offset{0};
for (int64_t i = 0; i < num_of_queries; ++i) {
int offset{0};
for (int i = 0; i < num_of_queries; ++i) {
std::vector<float> item_scores;
std::vector<FieldDataPtr> item_field_data;
auto item_topk = topks[i];
item_scores.reserve(item_topk);
for (int64_t j = 0; j < item_topk; ++j) {
for (int j = 0; j < item_topk; ++j) {
item_scores.emplace_back(scores.at(offset + j));
}
item_field_data.reserve(fields_data.size());
Expand Down Expand Up @@ -842,7 +843,7 @@ MilvusClientImpl::CalcDistance(const CalcDistanceArguments& arguments, DistanceA
}

// suppose vectors is not empty, already checked by Validate()
data_array->set_dim(vectors[0].size());
data_array->set_dim(static_cast<int>(vectors[0].size()));
} else {
auto data_ptr = std::static_pointer_cast<BinaryVecFieldData>(arg_vectors);
auto& str = *data_array->mutable_binary_vector();
Expand All @@ -853,7 +854,7 @@ MilvusClientImpl::CalcDistance(const CalcDistanceArguments& arguments, DistanceA
for (auto& vector : vectors) {
str.append(vector);
}
data_array->set_dim(dimensions);
data_array->set_dim(static_cast<int>(dimensions));
}

} else if (arg_vectors->Type() == DataType::INT64) {
Expand Down Expand Up @@ -960,7 +961,7 @@ MilvusClientImpl::Flush(const std::vector<std::string>& collection_names, const
return Status::OK();
}

return waitForStatus(
return WaitForStatus(
[&segment_count, &flush_segments, &finished_count, this](Progress& p) -> Status {
p.total_ = segment_count;

Expand Down Expand Up @@ -1117,7 +1118,7 @@ MilvusClientImpl::ManualCompaction(const std::string& collection_name, uint64_t
return status;
}

auto pre = [&collection_name, &travel_timestamp, &collection_desc]() {
auto pre = [&travel_timestamp, &collection_desc]() {
proto::milvus::ManualCompactionRequest rpc_request;
rpc_request.set_collectionid(collection_desc.ID());
rpc_request.set_timetravel(travel_timestamp);
Expand Down Expand Up @@ -1223,7 +1224,7 @@ MilvusClientImpl::ListCredUsers(std::vector<std::string>& users) {
}

Status
MilvusClientImpl::waitForStatus(std::function<Status(Progress&)> query_function,
MilvusClientImpl::WaitForStatus(const std::function<Status(Progress&)>& query_function,
const ProgressMonitor& progress_monitor) {
// no need to check
if (progress_monitor.CheckTimeout() == 0) {
Expand Down
8 changes: 4 additions & 4 deletions src/impl/MilvusClientImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,16 @@ class MilvusClientImpl : public MilvusClient {
* @param [in] progress_monitor timeout setting for waiting progress
* @return Status, the final status
*/
Status
waitForStatus(std::function<Status(Progress&)> query_function, const ProgressMonitor& progress_monitor);
static Status
WaitForStatus(const std::function<Status(Progress&)>& query_function, const ProgressMonitor& progress_monitor);

/**
* @brief template for public api call
* validate -> pre -> rpc -> wait_for_status -> post
*/
template <typename Request, typename Response>
Status
apiHandler(std::function<Status(void)> validate, std::function<Request(void)> pre,
apiHandler(const std::function<Status(void)>& validate, std::function<Request(void)> pre,
Status (MilvusConnection::*rpc)(const Request&, Response&, const GrpcOpts&),
std::function<Status(const Response&)> wait_for_status, std::function<void(const Response&)> post,
const GrpcOpts& options = GrpcOpts{}) {
Expand All @@ -215,7 +215,7 @@ class MilvusClientImpl : public MilvusClient {
auto status = std::bind(rpc, connection_.get(), std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3)(rpc_request, rpc_response, options);
if (!status.IsOk()) {
// resp's status already checked in connection class
// response's status already checked in connection class
return status;
}

Expand Down
13 changes: 6 additions & 7 deletions src/impl/MilvusConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <memory>

#include "grpcpp/security/credentials.h"
#include "milvus/types/ConnectParam.h"

using grpc::Channel;
using grpc::ClientContext;
Expand Down Expand Up @@ -63,8 +62,8 @@ MilvusConnection::~MilvusConnection() {
Status
MilvusConnection::Connect(const ConnectParam& param) {
authorization_value_ = param.Authorizations();
std::shared_ptr<grpc::ChannelCredentials> creds{nullptr};
auto& uri = param.Uri();
std::shared_ptr<grpc::ChannelCredentials> credentials{nullptr};
auto uri = param.Uri();

::grpc::ChannelArguments args;
args.SetMaxSendMessageSize(-1); // max send message size: 2GB
Expand All @@ -74,12 +73,12 @@ MilvusConnection::Connect(const ConnectParam& param) {
if (!param.ServerName().empty()) {
args.SetSslTargetNameOverride(param.ServerName());
}
creds = createTlsCredentials(param.Cert(), param.Key(), param.CaCert());
credentials = createTlsCredentials(param.Cert(), param.Key(), param.CaCert());
} else {
creds = ::grpc::InsecureChannelCredentials();
credentials = ::grpc::InsecureChannelCredentials();
}

channel_ = ::grpc::CreateCustomChannel(uri, creds, args);
channel_ = ::grpc::CreateCustomChannel(uri, credentials, args);
auto connected = channel_->WaitForConnected(std::chrono::system_clock::now() +
std::chrono::milliseconds{param.ConnectTimeout()});
if (connected) {
Expand Down Expand Up @@ -175,7 +174,7 @@ MilvusConnection::DropPartition(const proto::milvus::DropPartitionRequest& reque
Status
MilvusConnection::HasPartition(const proto::milvus::HasPartitionRequest& request, proto::milvus::BoolResponse& response,
const GrpcContextOptions& options) {
return grpcCall("HasParition", &Stub::HasPartition, request, response, options);
return grpcCall("HasPartition", &Stub::HasPartition, request, response, options);
}

Status
Expand Down
20 changes: 10 additions & 10 deletions src/impl/MilvusConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,23 +238,23 @@ class MilvusConnection {
std::shared_ptr<grpc::Channel> channel_;
std::string authorization_value_{};

Status
statusByProtoResponse(const proto::common::Status& status) {
if (status.error_code() != proto::common::ErrorCode::Success) {
static Status
StatusByProtoResponse(const proto::common::Status& status) {
if (status.code() != proto::common::ErrorCode::Success) {
return Status{StatusCode::SERVER_FAILED, status.reason()};
}
return Status::OK();
}

template <typename Response>
Status
statusByProtoResponse(const Response& response) {
static Status
StatusByProtoResponse(const Response& response) {
const auto& status = response.status();
return statusByProtoResponse(status);
return StatusByProtoResponse(status);
}

StatusCode
statusCodeFromGrpcStatus(const ::grpc::Status& grpc_status) {
static StatusCode
StatusCodeFromGrpcStatus(const ::grpc::Status& grpc_status) {
if (grpc_status.error_code() == ::grpc::StatusCode::DEADLINE_EXCEEDED) {
return StatusCode::TIMEOUT;
}
Expand Down Expand Up @@ -284,10 +284,10 @@ class MilvusConnection {
::grpc::Status grpc_status = (stub_.get()->*func)(&context, request, &response);

if (!grpc_status.ok()) {
return {statusCodeFromGrpcStatus(grpc_status), grpc_status.error_message()};
return {StatusCodeFromGrpcStatus(grpc_status), grpc_status.error_message()};
}

return statusByProtoResponse(response);
return StatusByProtoResponse(response);
}
};

Expand Down
2 changes: 0 additions & 2 deletions src/impl/Status.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#include "milvus/Status.h"

#include <cstring>

namespace milvus {

Status::Status(StatusCode code, std::string msg) : code_(code), msg_(std::move(msg)) {
Expand Down
6 changes: 3 additions & 3 deletions src/impl/TypeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ CreateProtoFieldData(const BinaryVecFieldData& field) {
for (const auto& item : data) {
std::copy(item.begin(), item.end(), std::back_inserter(vectors_data));
}
ret->set_dim(dim);
ret->set_dim(static_cast<int>(dim));
return ret;
}

Expand All @@ -416,11 +416,11 @@ CreateProtoFieldData(const FloatVecFieldData& field) {
auto& data = field.Data();
auto dim = data.front().size();
auto& vectors_data = *(ret->mutable_float_vector()->mutable_data());
vectors_data.Reserve(data.size() * dim);
vectors_data.Reserve(static_cast<int>(data.size() * dim));
for (const auto& item : data) {
vectors_data.Add(item.begin(), item.end());
}
ret->set_dim(dim);
ret->set_dim(static_cast<int>(dim));
return ret;
}

Expand Down
2 changes: 0 additions & 2 deletions src/impl/types/CalcDistanceArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#include "milvus/types/CalcDistanceArguments.h"

#include <algorithm>
#include <functional>
#include <set>
#include <unordered_map>

Expand Down
2 changes: 1 addition & 1 deletion src/impl/types/CollectionStat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ CollectionStat::RowCount() const {
// TODO: throw exception or log
return 0;
}
return std::atoll(iter->second.c_str());
return std::strtoll(iter->second.c_str(), nullptr, 10);
}

void
Expand Down
2 changes: 0 additions & 2 deletions src/impl/types/CompactionState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#include "milvus/types/CompactionState.h"

#include <string>

namespace milvus {

CompactionState::CompactionState() = default;
Expand Down
2 changes: 1 addition & 1 deletion src/impl/types/ConnectParam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ ConnectParam::Port() const {
return port_;
}

const std::string
std::string
ConnectParam::Uri() const {
return host_ + ":" + std::to_string(port_);
}
Expand Down
4 changes: 2 additions & 2 deletions src/impl/types/FieldSchema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ uint32_t
FieldSchema::Dimension() const {
auto iter = type_params_.find(FieldDim());
if (iter != type_params_.end()) {
return std::atol(iter->second.c_str());
return std::strtol(iter->second.c_str(), nullptr, 10);
}
return 0;
}
Expand All @@ -124,7 +124,7 @@ uint32_t
FieldSchema::MaxLength() const {
auto iter = type_params_.find(FieldMaxLength());
if (iter != type_params_.end()) {
return std::atol(iter->second.c_str());
return std::strtol(iter->second.c_str(), nullptr, 10);
}
return 0;
}
Expand Down
Loading

0 comments on commit c9312b1

Please sign in to comment.