diff --git a/Project.toml b/Project.toml
index a453257..a3f935d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,13 +1,23 @@
 name = "LRUCache"
 uuid = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
-version = "1.6.0"
+version = "1.6.1"
+
+[deps]
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+
+[weakdeps]
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+
+[extensions]
+SerializationExt = ["Serialization"]
 
 [compat]
 julia = "1"
 
 [extras]
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Random"]
+test = ["Test", "Random", "Serialization"]
diff --git a/ext/SerializationExt.jl b/ext/SerializationExt.jl
new file mode 100644
index 0000000..62e7360
--- /dev/null
+++ b/ext/SerializationExt.jl
@@ -0,0 +1,76 @@
+module SerializationExt
+using LRUCache
+using Serialization
+
+function Serialization.serialize(s::AbstractSerializer, lru::LRU{K, V}) where {K, V}
+    Serialization.writetag(s.io, Serialization.OBJECT_TAG)
+    serialize(s, typeof(lru))
+    @assert lru.currentsize == length(lru)
+    serialize(s, lru.currentsize)
+    serialize(s, lru.maxsize)
+    serialize(s, lru.hits)
+    serialize(s, lru.misses)
+    serialize(s, lru.lock)
+    serialize(s, lru.by)
+    serialize(s, lru.finalizer)
+    for (k, val) in lru
+        serialize(s, k)
+        serialize(s, val)
+        sz = lru.dict[k][3]
+        serialize(s, sz)
+    end
+end
+
+function Serialization.deserialize(s::AbstractSerializer, ::Type{LRU{K, V}}) where {K, V}
+    currentsize = Serialization.deserialize(s)
+    maxsize = Serialization.deserialize(s)
+    hits = Serialization.deserialize(s)
+    misses = Serialization.deserialize(s)
+    lock = Serialization.deserialize(s)
+    by = Serialization.deserialize(s)
+    finalizer = Serialization.deserialize(s)
+
+    dict = Dict{K, Tuple{V, LRUCache.LinkedNode{K}, Int}}()
+    sizehint!(dict, currentsize)
+    # Create node chain
+    first = nothing
+    node = nothing
+    for i in 1:currentsize
+        prev = node
+        k = deserialize(s)
+        node = LRUCache.LinkedNode{K}(k)
+        val = deserialize(s)
+        sz = deserialize(s)
+        dict[k] = (val, node, sz)
+        if i == 1
+            first = node
+            continue
+        else
+            prev.next = node
+            node.prev = prev
+        end
+    end
+    # close the chain if any node exists
+    if node !== nothing
+        node.next = first
+        first.prev = node
+    end
+
+    # Create a cyclic ordered set from the node chain
+    keyset = LRUCache.CyclicOrderedSet{K}()
+    keyset.first = first
+    keyset.length = currentsize
+
+    # Create the LRU
+    lru = LRU{K,V}(maxsize=maxsize)
+    lru.dict = dict
+    lru.keyset = keyset
+    lru.currentsize = currentsize
+    lru.hits = hits
+    lru.misses = misses
+    lru.lock = lock
+    lru.by = by
+    lru.finalizer = finalizer
+    lru
+end
+end
diff --git a/src/LRUCache.jl b/src/LRUCache.jl
index a9927b4..ee53067 100644
--- a/src/LRUCache.jl
+++ b/src/LRUCache.jl
@@ -316,4 +316,8 @@ function _finalize_evictions!(finalizer, evictions)
     return
 end
 
+if !isdefined(Base, :get_extension)
+    include("../ext/SerializationExt.jl")
+end
+
 end # module
diff --git a/test/runtests.jl b/test/runtests.jl
index b30b5eb..7c38007 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -297,3 +297,4 @@ end
 end
 
 include("originaltests.jl")
+include("serializationtests.jl")
diff --git a/test/serializationtests.jl b/test/serializationtests.jl
new file mode 100644
index 0000000..1364750
--- /dev/null
+++ b/test/serializationtests.jl
@@ -0,0 +1,51 @@
+using Serialization
+@testset "Large Serialize and Deserialize" begin
+
+    cache = LRU{Int, Int}(maxsize=100_000)
+
+    # Populate the cache with dummy data
+    num_entries_to_test = [0, 1, 2, 3, 4, 5, 100_000, 1_000_000]
+    for i in 0:maximum(num_entries_to_test)
+        # Add dummy data on all but the first iteration,
+        # to test an empty cache
+        i > 0 && (cache[i] = i+1)
+        i ∈ num_entries_to_test || continue
+        io = IOBuffer()
+        serialize(io, cache)
+        seekstart(io)
+        deserialized_cache = deserialize(io)
+
+        # Check that the cache is the same
+        @test cache.maxsize == deserialized_cache.maxsize
+        @test cache.currentsize == deserialized_cache.currentsize
+        @test cache.hits == deserialized_cache.hits
+        @test cache.misses == deserialized_cache.misses
+        @test cache.by == deserialized_cache.by
+        @test cache.finalizer == deserialized_cache.finalizer
+        @test cache.keyset.length == deserialized_cache.keyset.length
+        @test issetequal(collect(cache), collect(deserialized_cache))
+        # Check that the cache has the same keyset
+        @test length(cache.keyset) == length(deserialized_cache.keyset)
+        @test all(((c_val, d_val),) -> c_val == d_val, zip(cache.keyset, deserialized_cache.keyset))
+        # Check that the cache keys, values, and sizes are the same
+        for (key, (c_value, c_node, c_s)) in cache.dict
+            d_value, d_node, d_s = deserialized_cache.dict[key]
+            c_value == d_value || @test false
+            c_node.val == d_node.val || @test false
+            c_s == d_s || @test false
+        end
+    end
+end
+
+@testset "Serialize mutable references" begin
+    lru = LRU(; maxsize=5)
+    a = b = [1]
+    lru[1] = a
+    lru[2] = b
+    @test lru[1] === lru[2]
+    io = IOBuffer()
+    serialize(io, lru)
+    seekstart(io)
+    lru2 = deserialize(io)
+    @test lru2[1] === lru2[2]
+end