Skip to content

Commit

Permalink
Fix auto transfer for constants
Browse files Browse the repository at this point in the history
Note: auto transfers currently don't handle multi-device, will need fixing.
  • Loading branch information
lukstafi committed Dec 31, 2024
1 parent a58eabb commit a91751b
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,22 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
Option.(join @@ map lowered ~f:(fun optim -> optim.Low_level.merge_node)));
}

let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
let%track3_sexp alloc_if_needed parent_context ~key ~data:node ctx_arrays =
if Tnode.is_in_context_force ~use_host_memory key 43 && not (Map.mem ctx_arrays key) then (
let stream = parent_context.stream in
[%log Tn.debug_name key];
[%log (key : Tnode.t)];
let default () =
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
let dst_ptr =
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
in
(if Utils.settings.automatic_host_transfers && Tn.known_constant key then
match key.array with
| (lazy (Some hosted)) ->
Device.from_host ~dst_ptr ~dst:parent_context hosted;
key.host_modified <- false
| _ -> ());
dst_ptr
in
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
let device = stream.device in
Expand Down Expand Up @@ -423,8 +433,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
~from_prior_context:code.from_prior_context;
let (inputs, outputs), merge_buffer_input = Low_level.input_and_output_nodes code.lowered in
let ctx_arrays =
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays
~f:(alloc_if_needed context.stream)
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays ~f:(alloc_if_needed context)
in
let bindings, schedule = link context code.code ctx_arrays in
let context = make_child ~ctx_arrays context in
Expand All @@ -443,7 +452,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
~f:
(Option.map ~f:(fun l ->
Hashtbl.fold l.Low_level.traced_store ~init:context.ctx_arrays
~f:(alloc_if_needed context.stream)))
~f:(alloc_if_needed context)))
in
let bindings, schedules = link_batch context code_batch.code_batch ctx_arrays in
Array.fold_mapi schedules ~init:context ~f:(fun i context -> function
Expand Down

0 comments on commit a91751b

Please sign in to comment.