Fix #16 #17
@test_broken collect(A * y) ≈ collect(A) * collect(y)

@test_broken gradient(A -> sum(abs, A * y), A)[1] isa CuArray  # gather!(dst::JLArray, ...) fails
Is the solution here to use gather (and take on a dep)?
I think this needs gather/scatter from NNlibCUDA to work on the GPU. And since there's no corresponding code for non-CuArray GPUArrays, I think it can't work with this fake JLArray.
For testing it, you could set up the whole buildkite story to run honest CUDA tests. But perhaps it's not worth it, and this package should just trust NNlib + NNlibCUDA to test things. And perhaps Flux to test the integration?
Xref FluxML/NNlib.jl#427 too: it would be nice if forgetting to load NNlibCUDA gave friendly errors rather than scalar indexing errors.
It would be nicer if that could be loaded automatically, of course.
  for x in data
-   isnothing(_findval(x, labels)) && error("Value $x not found in labels")
+   isnothing(_findval(x, labels)) && throw(ArgumentError("Value x = $x not found in labels = $labels"))
I changed these error types partly so that tests can distinguish scalar indexing errors from helpful messages.
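As a rough illustration of why the error type matters, a test can assert on `ArgumentError` specifically, so an unrelated failure such as a scalar indexing error would not satisfy it. This is only a sketch: `find_label` and `check_labels` are hypothetical stand-ins, not the package's `_findval` or its actual tests.

```julia
using Test

# Hypothetical stand-in for the package's internal lookup (not the real `_findval`).
find_label(x, labels) = findfirst(isequal(x), labels)

# Hypothetical helper that mirrors the loop in the diff above.
function check_labels(data, labels)
    for x in data
        isnothing(find_label(x, labels)) &&
            throw(ArgumentError("Value x = $x not found in labels = $labels"))
    end
    return nothing
end

# A missing label must raise ArgumentError specifically, not just any error.
@test_throws ArgumentError check_labels([1, 5], 1:3)

# Valid input passes silently.
@test check_labels([1, 2], 1:3) === nothing
```

With the old `error(...)` call the thrown type would be `ErrorException`, which a test cannot tell apart from other generic failures by type alone.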
end
end
return OneHotArray(indices, length(labels))
end
function _onehotbatch(data::AbstractArray, labels)  # this works for GPUArrays too
One reason to change this is to avoid ever making an MVector or something weird like that:
- function _onehotbatch(data::AbstractArray, labels)  # this works for GPUArrays too
+ function _onehotbatch(data::AbstractGPUArray, labels)
indices = UInt32[something(_findval(x, labels), default_index) for x in data]
return OneHotArray(indices, length(labels))
end
function _onehotbatch(data::AbstractArray, labels, default)
- function _onehotbatch(data::AbstractArray, labels, default)
+ function _onehotbatch(data::AbstractGPUArray, labels, default)
Unlike FluxML/Flux.jl#1959, this uses map over arrays. Some duplication, unfortunately. Possibly the new method should be restricted to AbstractGPUArrays?
Closes #16
Also tries to organise the tests just a little bit better.
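To make the "map over arrays" idea in the description concrete, here is a minimal sketch of a map-based label lookup on plain CPU arrays. The name `label_indices` is hypothetical and this is not the PR's implementation; it only shows the general shape of replacing an element-by-element comprehension with `map`, and it assumes nothing about how the lookup behaves on a real GPU array.

```julia
# Hypothetical helper; a sketch of a map-based lookup, not the package's code.
function label_indices(data::AbstractArray, labels)
    # One whole-array `map` call instead of indexing `data` element by element.
    indices = map(x -> findfirst(isequal(x), labels), data)
    # `findfirst` returns `nothing` for values that are absent from `labels`.
    any(isnothing, indices) &&
        throw(ArgumentError("some values were not found in labels = $labels"))
    # Convert to the compact index type used in the diff above.
    return map(UInt32, indices)
end

idx = label_indices(['a', 'c', 'b'], 'a':'c')
```

Keeping the work inside `map` leaves it to the array type to decide how the operation is carried out, which is the property the description relies on.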