Add support for serialization of large caches #46

Merged · 18 commits · Jan 25, 2024
Changes from 2 commits
13 changes: 10 additions & 3 deletions Project.toml
@@ -1,13 +1,20 @@
 name = "LRUCache"
 uuid = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
-version = "1.6.0"
+version = "1.6.1"

 [compat]
-julia = "1"
+julia = "≥ 1.9"

 [extras]
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

 [targets]
-test = ["Test", "Random"]
+test = ["Test", "Random", "Serialization"]
+
+[weakdeps]
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+
+[extensions]
+SerializationExt = ["Serialization"]
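
With the weak dependency declared above, Julia ≥ 1.9 should load the SerializationExt extension automatically once both LRUCache and Serialization are in the session, so its serialize/deserialize methods take over for LRU values. A minimal usage sketch (the file name is illustrative):

    using LRUCache, Serialization

    cache = LRU{Int, Int}(maxsize = 10)
    cache[1] = 2
    serialize("cache.jls", cache)        # dispatches to the extension's custom method
    restored = deserialize("cache.jls")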
83 changes: 83 additions & 0 deletions ext/SerializationExt.jl
@@ -0,0 +1,83 @@
module SerializationExt
export serialize, deserialize
using LRUCache
using Serialization

# Serialization of large LRUs causes a stack overflow error, so we
# create a custom serializer that represents LinkedNodes as Ints
function Serialization.serialize(s::AbstractSerializer, lru::LRU{K, V}) where {K, V}
# Create a mapping from memory address to id
node_map = Dict{Ptr, Int}()
Member:
Rather than using Ptr's, I think it is simpler (and maybe more robust) to use an IdDict:

Suggested change:
-    # Create a mapping from memory address to id
-    node_map = Dict{Ptr, Int}()
+    # Create a mapping from object to id. Here we use `IdDict` to use object identity as the hash.
+    node_map = IdDict{LinkedNode{K}, Int}()

then this can be indexed just as node_map[node] = id rather than using pointer_from_objref.

Contributor (author):
This is great, thanks.
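
For readers unfamiliar with IdDict: it hashes and compares keys by object identity (===), so two distinct nodes are separate keys even when they hold equal values. A small standalone sketch of the lookup the suggestion has in mind (LinkedNode is an internal LRUCache type, so this is illustrative only):

    using LRUCache

    node_map = IdDict{LRUCache.LinkedNode{Int}, Int}()
    a = LRUCache.LinkedNode{Int}(1)
    b = LRUCache.LinkedNode{Int}(1)   # same payload, different object
    node_map[a] = 1
    node_map[b] = 2                   # still a separate entry: IdDict keys by identity
    @assert node_map[a] == 1 && node_map[b] == 2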

# Create mapping for first node
id = 1
first_node = node = lru.keyset.first
node_map[pointer_from_objref(node)] = id
# Go through the rest of the nodes in the cycle and create a mapping
node = node.next
while node != first_node
id += 1
node_map[pointer_from_objref(node)] = id
node = node.next
end
@assert id == length(lru) == lru.keyset.length == length(lru.dict)
# By this point, the first node has id 1 and the last node has id length(lru)
# so when deserializing, we can infer the order by the id
# Create the dict with ids instead of nodes
dict = Dict{K, Tuple{V, Int, Int}}()
for (key, (value, node, s)) in lru.dict
id = node_map[pointer_from_objref(node)]
dict[key] = (value, id, s)
end
Serialization.writetag(s.io, Serialization.OBJECT_TAG)
Serialization.serialize(s, typeof(lru))
Serialization.serialize(s, dict)
Serialization.serialize(s, lru.currentsize)
Serialization.serialize(s, lru.maxsize)
Serialization.serialize(s, lru.hits)
Serialization.serialize(s, lru.misses)
Serialization.serialize(s, lru.lock)
Serialization.serialize(s, lru.by)
Serialization.serialize(s, lru.finalizer)
end

function Serialization.deserialize(s::AbstractSerializer, ::Type{LRU{K, V}}) where {K, V}
dict_with_ids = Serialization.deserialize(s)
currentsize = Serialization.deserialize(s)
maxsize = Serialization.deserialize(s)
hits = Serialization.deserialize(s)
misses = Serialization.deserialize(s)
lock = Serialization.deserialize(s)
by = Serialization.deserialize(s)
finalizer = Serialization.deserialize(s)
# Create a new keyset and mapping from id to node
n_nodes = length(dict_with_ids)
nodes = Vector{LRUCache.LinkedNode{K}}(undef, n_nodes)
dict = Dict{K, Tuple{V, LRUCache.LinkedNode{K}, Int}}()
# Create the nodes, but don't link them yet
for (key, (value, id, s)) in dict_with_ids
nodes[id] = LRUCache.LinkedNode{K}(key)
dict[key] = (value, nodes[id], s)
end
# Link the nodes
for (idx, node) in enumerate(nodes)
node.next = nodes[idx % n_nodes + 1]
node.prev = nodes[idx == 1 ? n_nodes : idx - 1]
Collaborator:
I think you can write these two right hand sides using mod1(idx+1, n_nodes) and mod1(idx-1, n_nodes), which I would find more clear.

Contributor (author):
Didn't know about the mod1 function, thanks! That's much cleaner.
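
For reference, mod1(x, n) maps x into 1:n with wrap-around, so the two right-hand sides become nodes[mod1(idx + 1, n_nodes)] and nodes[mod1(idx - 1, n_nodes)]. A quick standalone check of the wrap-around behaviour:

    n_nodes = 5
    @assert mod1(5 + 1, n_nodes) == 1   # the node after the last one is the first
    @assert mod1(1 - 1, n_nodes) == 5   # the node before the first one is the last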

end
# Create keyset with first node and n_nodes
keyset = LRUCache.CyclicOrderedSet{K}()
keyset.first = nodes[1]
keyset.length = n_nodes
# Create LRU
lru = LRU{K,V}(maxsize=maxsize)
lru.dict = dict
lru.keyset = keyset
lru.currentsize = currentsize
lru.hits = hits
lru.misses = misses
lru.lock = lock
lru.by = by
lru.finalizer = finalizer
lru
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -297,3 +297,4 @@ end
 end

 include("originaltests.jl")
+include("serializationtests.jl")
39 changes: 39 additions & 0 deletions test/serializationtests.jl
@@ -0,0 +1,39 @@
using Serialization
@testset "Serialize and Deserialize" begin

cache = LRU{Int, Int}(maxsize=100_000)

# Populate the cache with dummy data
for i in 1:1_000_000
cache[i] = i+1
end
serialize("cache.jls", cache)
deserialized_cache = deserialize("cache.jls")
rm("cache.jls")

# Check that the cache is the same
@test cache.maxsize == deserialized_cache.maxsize
@test cache.currentsize == deserialized_cache.currentsize
@test cache.hits == deserialized_cache.hits
@test cache.misses == deserialized_cache.misses
@test cache.by == deserialized_cache.by
@test cache.finalizer == deserialized_cache.finalizer
@test cache.keyset.length == deserialized_cache.keyset.length
@test length(cache.dict) == length(deserialized_cache.dict)
# Check that the cache has the same keyset
c_node = cache.keyset.first
d_node = deserialized_cache.keyset.first
for i in 1:cache.keyset.length
c_node.val == d_node.val || @test false
Collaborator:
Why not simply @test c_node.val == d_node.val?

Contributor (author):
This is just personal taste: for a large dict of 100k entries, I don't want to add 100k tests that compare each element. I just want one test that fails if any element is different, which I believe these do.

Collaborator:
Sure, that is true. On the other hand, the test could also have been run with just a handful of elements in the LRU cache, I believe. I haven't timed it, and I expect this is all very fast because it is just Ints, but 100_000 did jump out as a rather large number for a simple test case.

Member:
For small caches, the regular serialization method works fine and doesn't stack overflow, so we need a big one to really test the new method.

Contributor (author):
^ exactly right

Collaborator:
Well, yes and no, since the new method is now always called, irrespective of the size of the cache. And it is written in such a way that it should not suffer from the same flaws, so I would think that if it passes the test for a smaller cache, that provides sufficient guarantees. But I am fine with the current tests.

Is it clear what was causing the default serialisation strategy to fail? My guess of why it was entering an infinite loop would apply irrespective of the size of the cache, so that is not consistent with the original method working for small caches.
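
(A plausible mechanism, offered here as an assumption rather than a confirmed diagnosis: the default serializer recurses through each node's next field, so the recursion depth grows with the number of nodes; a short chain serializes fine while a long one exhausts the stack, which would match the size-dependent behaviour. The toy ChainNode type and the 200_000 figure below are illustrative only, not LRUCache internals.)

    using Serialization

    # Toy self-referential node, similar in shape to a linked-list entry
    mutable struct ChainNode
        val::Int
        next::Union{ChainNode, Nothing}
    end

    function build_chain(n)
        head = ChainNode(1, nothing)
        node = head
        for i in 2:n
            node.next = ChainNode(i, nothing)   # each link adds one level of serialize recursion
            node = node.next
        end
        return head
    end

    serialize(IOBuffer(), build_chain(100))       # short chain: fine
    serialize(IOBuffer(), build_chain(200_000))   # long chain: may hit a StackOverflowError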

Contributor:
Maybe as a middle ground, the following gives a single test (well, two), but still checks that all values are equal:

@test length(cache.keyset) == length(deserialized_cache.keyset)
@test all(((c_val, d_val),) -> c_val == d_val, zip(cache.keyset, deserialized_cache.keyset))

c_node = c_node.next
d_node = d_node.next
end
# Check that the cache keys, values, and sizes are the same
for (key, (c_value, c_node, c_s)) in cache.dict
d_value, d_node, d_s = deserialized_cache.dict[key]
c_value == d_value || @test false
c_node.val == d_node.val || @test false
c_s == d_s || @test false
Collaborator:
Same question for these tests.

end
end
