Skip to content

Commit

Permalink
replace logger inside diskann with vsag logger (#245)
Browse files Browse the repository at this point in the history
Signed-off-by: wxy407827 <[email protected]>
  • Loading branch information
wxyucs authored Jan 6, 2025
1 parent 6852879 commit b11ff63
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 20 deletions.
7 changes: 4 additions & 3 deletions extern/diskann/DiskANN/include/logger_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ class ANNStreamBuf : public std::basic_streambuf<char>
return true; // because stdout and stderr are always open.
}
DISKANN_DLLEXPORT void close();
DISKANN_DLLEXPORT virtual int underflow();
DISKANN_DLLEXPORT virtual int overflow(int c);
DISKANN_DLLEXPORT virtual int sync();
DISKANN_DLLEXPORT virtual int underflow() override;
DISKANN_DLLEXPORT virtual int overflow(int c) override;
DISKANN_DLLEXPORT virtual int sync() override;

private:
FILE *_fp;
char *_buf;
int _bufIndex;
std::mutex _mutex;
LogLevel _logLevel;
std::function<void(LogLevel, const char *)> g_logger;

int flush();
void logImpl(char *str, int numchars);
Expand Down
21 changes: 15 additions & 6 deletions extern/diskann/DiskANN/src/logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
#include "logger_impl.h"
#include "windows_customizations.h"

namespace vsag
{
extern std::function<void(diskann::LogLevel, const char*)>
vsag_get_logger();
} // namespace vsag

namespace diskann
{

Expand All @@ -16,13 +22,7 @@ DISKANN_DLLEXPORT ANNStreamBuf cerrBuff(stderr);

DISKANN_DLLEXPORT std::basic_ostream<char> cout(&coutBuff);
DISKANN_DLLEXPORT std::basic_ostream<char> cerr(&cerrBuff);
std::function<void(LogLevel, const char *)> g_logger;

void SetCustomLogger(std::function<void(LogLevel, const char *)> logger)
{
g_logger = logger;
diskann::cout << "Set Custom Logger" << std::endl;
}

ANNStreamBuf::ANNStreamBuf(FILE *fp)
{
Expand All @@ -40,10 +40,14 @@ ANNStreamBuf::ANNStreamBuf(FILE *fp)

std::memset(_buf, 0, (BUFFER_SIZE) * sizeof(char));
setp(_buf, _buf + BUFFER_SIZE - 1);

g_logger = vsag::vsag_get_logger();
g_logger(_logLevel, "diskann switch logger");
}

ANNStreamBuf::~ANNStreamBuf()
{
g_logger = nullptr;
sync();
_fp = nullptr; // we'll not close because we can't.
delete[] _buf;
Expand Down Expand Up @@ -80,8 +84,13 @@ int ANNStreamBuf::flush()
pbump(-num);
return num;
}

void ANNStreamBuf::logImpl(char *str, int num)
{
// remove the newline at the end of str, 'cause logger provides
if (num > 0 and str[num - 1] == '\n') {
--num;
}
str[num] = '\0'; // Safe. See the c'tor.
// Invoke the OLS custom logging function.
if (g_logger)
Expand Down
4 changes: 2 additions & 2 deletions extern/diskann/diskann.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ set(DISKANN_SOURCES
add_library(diskann STATIC ${DISKANN_SOURCES})
# work
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
target_compile_options(diskann PRIVATE -mavx -msse2 -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors)
target_compile_options(diskann PRIVATE -mavx -msse2 -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DENABLE_CUSTOM_LOGGER=1)
else ()
target_compile_options(diskann PRIVATE -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors)
target_compile_options(diskann PRIVATE -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DENABLE_CUSTOM_LOGGER=1)
endif ()
set_property(TARGET diskann PROPERTY CXX_STANDARD 17)
add_dependencies(diskann boost openblas)
Expand Down
8 changes: 4 additions & 4 deletions src/bitset_impl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ TEST_CASE("roaringbitmap example", "[ut][bitset]") {
r1.setCopyOnWrite(true);

uint32_t compact_size = r1.getSizeInBytes();
std::cout << "size before run optimize " << size << " bytes, and after " << compact_size
<< " bytes." << std::endl;
// std::cout << "size before run optimize " << size << " bytes, and after " << compact_size
// << " bytes." << std::endl;

// create a new bitmap with varargs
Roaring r2 = Roaring::bitmapOf(5, 1, 2, 3, 5, 6);

r2.printf();
printf("\n");
// r2.printf();
// printf("\n");

// create a new bitmap with initializer list
Roaring r2i = Roaring::bitmapOfList({1, 2, 3, 5, 6});
Expand Down
2 changes: 1 addition & 1 deletion src/index/diskann_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,6 @@ TEST_CASE("split building process", "[diskann][ut]") {
}
}
float recall_full = correct / 1000;
std::cout << "Recall: " << recall_full << std::endl;
vsag::logger::debug("Recall: " + std::to_string(recall_full));
REQUIRE(recall_full == recall_partial);
}
8 changes: 8 additions & 0 deletions src/vsag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "vsag/vsag.h"

#include <../extern/diskann/DiskANN/include/logger.h>
#include <cpuinfo.h>

#include <sstream>
Expand All @@ -23,6 +24,13 @@
#include "simd/simd.h"
#include "version.h"

namespace vsag {
std::function<void(diskann::LogLevel, const char*)>
vsag_get_logger() {
return [](diskann::LogLevel, const char* msg) { vsag::logger::debug(msg); };
}
} // namespace vsag

namespace vsag {

std::string
Expand Down
7 changes: 5 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# unittests
file (GLOB_RECURSE UNIT_TESTS "../src/*_test.cpp")
add_executable (unittests ${UNIT_TESTS}
test_main.cpp
fixtures/fixtures.cpp
fixtures/test_logger.cpp
)
if (DIST_CONTAINS_SSE)
target_compile_definitions (unittests PRIVATE ENABLE_SSE=1)
Expand All @@ -17,7 +19,7 @@ if (DIST_CONTAINS_AVX512)
target_compile_definitions (unittests PRIVATE ENABLE_AVX512=1)
endif ()
target_include_directories (unittests PRIVATE "./fixtures")
target_link_libraries (unittests PRIVATE Catch2::Catch2WithMain vsag simd)
target_link_libraries (unittests PRIVATE Catch2::Catch2 vsag simd)
add_dependencies (unittests spdlog Catch2)

# function tests
Expand All @@ -27,9 +29,10 @@ add_executable (functests
fixtures/fixtures.cpp
fixtures/test_dataset.cpp
fixtures/test_dataset_pool.cpp
fixtures/test_logger.cpp
)
target_include_directories (functests PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/spdlog/install/include
${HDF5_INCLUDE_DIRS})
target_link_libraries (functests PRIVATE Catch2::Catch2WithMain vsag simd)
target_link_libraries (functests PRIVATE Catch2::Catch2 vsag simd)
add_dependencies (functests spdlog Catch2)
24 changes: 24 additions & 0 deletions tests/fixtures/test_logger.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test_logger.h"

#include <catch2/catch_message.hpp>

namespace fixtures {

TestLogger logger;

} // namespace fixtures
86 changes: 86 additions & 0 deletions tests/fixtures/test_logger.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <catch2/catch_message.hpp>
#include <mutex>

#include "vsag/vsag.h"

namespace fixtures {

class TestLogger : public vsag::Logger {
public:
inline void
SetLevel(Level log_level) override {
std::lock_guard<std::mutex> lock(mutex_);
level_ = log_level - vsag::Logger::Level::kTRACE;
}

inline void
Trace(const std::string& msg) override {
std::lock_guard<std::mutex> lock(mutex_);
if (level_ <= 0) {
UNSCOPED_INFO("[test-logger]::[trace] " + msg);
}
}

inline void
Debug(const std::string& msg) override {
std::lock_guard<std::mutex> lock(mutex_);
if (level_ <= 1) {
UNSCOPED_INFO("[test-logger]::[debug] " + msg);
}
}

inline void
Info(const std::string& msg) override {
std::lock_guard<std::mutex> lock(mutex_);
if (level_ <= 2) {
UNSCOPED_INFO("[test-logger]::[info] " + msg);
}
}

inline void
Warn(const std::string& msg) override {
std::lock_guard<std::mutex> lock(mutex_);
if (level_ <= 3) {
UNSCOPED_INFO("[test-logger]::[warn] " + msg);
}
}

inline void
Error(const std::string& msg) override {
std::lock_guard<std::mutex> lock(mutex_);
if (level_ <= 4) {
UNSCOPED_INFO("[test-logger]::[error] " + msg);
}
}

void
Critical(const std::string& msg) override {
std::lock_guard<std::mutex> lock(mutex_);
if (level_ <= 5) {
UNSCOPED_INFO("[test-logger]::[critical] " + msg);
}
}

private:
int64_t level_ = 0;
std::mutex mutex_;
};

extern TestLogger logger;

} // namespace fixtures
31 changes: 31 additions & 0 deletions tests/test_main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <catch2/catch_session.hpp>

#include "./fixtures/test_logger.h"
#include "vsag/vsag.h"

int
main(int argc, char** argv) {
// your setup ...
vsag::Options::Instance().set_logger(&fixtures::logger);

int result = Catch::Session().run(argc, argv);

// your clean-up...

return result;
}
35 changes: 33 additions & 2 deletions tests/test_multi_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
#include <future>
#include <iostream>
#include <nlohmann/json.hpp>
#include <sstream>
#include <thread>

#include "default_logger.h"
#include "fixtures/test_logger.h"
#include "fixtures/thread_pool.h"
#include "vsag/options.h"
#include "vsag/vsag.h"
Expand All @@ -34,16 +37,38 @@ query_knn(std::shared_ptr<vsag::Index> index,
if (result.value()->GetDim() != 0 && result.value()->GetIds()[0] == id) {
return 1.0;
} else {
std::cout << result.value()->GetDim() << " " << result.value()->GetIds()[0] << " " << id
<< std::endl;
std::stringstream ss;
ss << "recall failure: dim " << result.value()->GetDim() << ", id "
<< result.value()->GetIds()[0] << ", expected_id " << id;
fixtures::logger.Debug(ss.str());
}
} else if (result.error().type == vsag::ErrorType::INTERNAL_ERROR) {
std::cerr << "failed to perform knn search on index" << std::endl;
}
return 0.0;
}

// catch2 logger is NOT supported to be used in multi-threading tests, so
// we need to replace it at the start of all the test cases in this file
class LoggerReplacer {
public:
LoggerReplacer() {
origin_logger_ = vsag::Options::Instance().logger();
vsag::Options::Instance().set_logger(&logger_);
}

~LoggerReplacer() {
vsag::Options::Instance().set_logger(origin_logger_);
}

private:
vsag::Logger* origin_logger_;
vsag::DefaultLogger logger_;
};

TEST_CASE("DiskAnn Multi-threading", "[ft][diskann]") {
LoggerReplacer _;

int dim = 65; // Dimension of the elements
int max_elements = 1000; // Maximum number of elements, should be known beforehand
int max_degree = 16; // Tightly connected with internal dimensionality of the data
Expand Down Expand Up @@ -116,6 +141,8 @@ TEST_CASE("DiskAnn Multi-threading", "[ft][diskann]") {
}

TEST_CASE("HNSW Multi-threading", "[ft][hnsw]") {
LoggerReplacer _;

int dim = 16; // Dimension of the elements
int max_elements = 1000; // Maximum number of elements, should be known beforehand
int max_degree = 16; // Tightly connected with internal dimensionality of the data
Expand Down Expand Up @@ -185,6 +212,8 @@ TEST_CASE("HNSW Multi-threading", "[ft][hnsw]") {
}

TEST_CASE("multi-threading read-write test", "[ft][hnsw]") {
LoggerReplacer _;

// avoid too much slow task logs
vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN);

Expand Down Expand Up @@ -278,6 +307,8 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") {
}

TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hnsw]") {
LoggerReplacer _;

// avoid too much slow task logs
vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN);

Expand Down

0 comments on commit b11ff63

Please sign in to comment.