Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement iterators that mirror Stdlib.Array #125

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/domainslib.ml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
module Parray = Parray
module Chan = Chan
module Task = Task
8 changes: 8 additions & 0 deletions lib/multi_channel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type 'a t = {
}

let dls_make_key () =
let open Stdlib in
Domain.DLS.new_key (fun () ->
{
id = -1;
Expand All @@ -57,6 +58,7 @@ let rec log2 n =
if n <= 1 then 0 else 1 + (log2 (n asr 1))

let make ?(recv_block_spins = 2048) n =
let open Stdlib in
{ channels = Array.init n (fun _ -> Ws_deque.create ());
foreign_queue = Foreign_queue.create ();
waiters = Chan.make_unbounded ();
Expand All @@ -67,11 +69,13 @@ let make ?(recv_block_spins = 2048) n =

let register_domain mchan =
let id = Atomic.fetch_and_add mchan.next_domain_id 1 in
let open Stdlib in
assert(id < Array.length mchan.channels);
id

let init_domain_state mchan dls_state =
let id = register_domain mchan in
let open Stdlib in
let len = Array.length mchan.channels in
dls_state.id <- id;
dls_state.steal_offsets <- Array.init (len - 1) (fun i -> (id + i + 1) mod len);
Expand All @@ -81,6 +85,7 @@ let init_domain_state mchan dls_state =
let get_local_state mchan =
let dls_state = Domain.DLS.get mchan.dls_key in
if dls_state.id >= 0 then begin
let open Stdlib in
assert (dls_state.id < Array.length mchan.channels);
dls_state
end
Expand Down Expand Up @@ -120,11 +125,13 @@ let send_foreign mchan v =

let send mchan v =
let id = (get_local_state mchan).id in
let open Stdlib in
Ws_deque.push (Array.unsafe_get mchan.channels id) v;
check_waiters mchan

let rec recv_poll_loop mchan dls cur_offset =
let offsets = dls.steal_offsets in
let open Stdlib in
let k = (Array.length offsets) - cur_offset in
if k = 0 then raise Exit
else begin
Expand All @@ -144,6 +151,7 @@ let rec recv_poll_loop mchan dls cur_offset =

let recv_poll_with_dls mchan dls =
try
let open Stdlib in
Ws_deque.pop (Array.unsafe_get mchan.channels dls.id)
with
| Exit ->
Expand Down
31 changes: 31 additions & 0 deletions lib/parray.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
type 'a t = 'a array

(** {1 Iterators} *)

let iter f a p =
let n = Stdlib.Array.length a in
Task.parallel_for p ~start:0 ~finish:(n-1) ~body:(fun i -> f @@ Stdlib.Array.get a i)

let iteri f a p =
let n = Stdlib.Array.length a in
Task.parallel_for p ~start:0 ~finish:(n-1) ~body:(fun i -> f i @@ Stdlib.Array.get a i)

let map f a p =
let n = Stdlib.Array.length a in
let res = Stdlib.Array.make n @@ f (Stdlib.Array.get a 0) in
Task.parallel_for p ~start:0 ~finish:(n-1) ~body:(fun i -> Stdlib.Array.set res i @@ f (Stdlib.Array.get a i));
res

let map_inplace f a p =
let n = Stdlib.Array.length a in
Task.parallel_for p ~start:0 ~finish:(n-1) ~body:(fun i -> Stdlib.Array.set a i @@ f (Stdlib.Array.get a i))

let mapi f a p =
let n = Stdlib.Array.length a in
let res = Stdlib.Array.make n @@ f 0 (Stdlib.Array.get a 0) in
Task.parallel_for p ~start:0 ~finish:(n-1) ~body:(fun i -> Stdlib.Array.set res i @@ f i (Stdlib.Array.get a i));
res

let mapi_inplace f a p =
let n = Stdlib.Array.length a in
Task.parallel_for p ~start:0 ~finish:(n-1) ~body:(fun i -> Stdlib.Array.set a i @@ f i (Stdlib.Array.get a i))
33 changes: 33 additions & 0 deletions lib/parray.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
type 'a t = 'a array

(** {1 Iterators} *)

val iter : ('a -> unit) -> 'a array -> Task.pool -> unit
(** [iter f a p] applies function [f] in turn to all
the elements of [a]. It is equivalent to
[f a.(0); f a.(1); ...; f a.(length a - 1); ()]. *)

val iteri : (int -> 'a -> unit) -> 'a array -> Task.pool -> unit
(** Same as {!iter}, but the
function is applied to the index of the element as first argument,
and the element itself as second argument. *)

val map : ('a -> 'b) -> 'a array -> Task.pool -> 'b array
(** [map f a p] applies function [f] to all the elements of [a],
and builds an array with the results returned by [f]:
[[| f a.(0); f a.(1); ...; f a.(length a - 1) |]]. *)

val map_inplace : ('a -> 'a) -> 'a array -> Task.pool -> unit
(** [map_inplace f a p] applies function [f] to all elements of [a],
and updates their values in place.
@since 5.1 *)

val mapi : (int -> 'a -> 'b) -> 'a array -> Task.pool -> 'b array
(** Same as {!map}, but the
function is applied to the index of the element as first argument,
and the element itself as second argument. *)

val mapi_inplace : (int -> 'a -> 'a) -> 'a array -> Task.pool -> unit
(** Same as {!map_inplace}, but the function is applied to the index of the
element as first argument, and the element itself as second argument.
@since 5.1 *)
5 changes: 5 additions & 0 deletions test/dune
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@
(libraries domainslib)
(modules test_parallel_find))

(test
(name test_parallel_array)
(libraries domainslib)
(modules test_parallel_array))

(test
(name test_parallel_scan)
(libraries domainslib)
Expand Down
43 changes: 43 additions & 0 deletions test/test_parallel_array.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
(* Generic tests for the parray module *)

open Domainslib

let test_map_inplace pool =
let a = [| 1 ; 2 ; 3 ; 4 |] in
Parray.map_inplace (fun x -> 2 * x) a pool;
let res = [| 2 ; 4 ; 6 ; 8 |] in
assert (a = res)

let test_mapi_inplace pool =
let a = [| 1 ; 2 ; 3 ; 4 |] in
Parray.mapi_inplace (fun _ x -> 2 * x) a pool;
let res = [| 2 ; 4 ; 6 ; 8 |] in
assert (a = res)

let test_map pool =
let a = [| 1 ; 2 ; 3 ; 4 |] in
let b = Parray.map (fun x -> 2 * x) a pool in
let res = [| 2 ; 4 ; 6 ; 8 |] in
assert (b = res)

let test_mapi pool =
let a = [| 1 ; 2 ; 3 ; 4 |] in
let b = Parray.mapi (fun _ x -> 2 * x) a pool in
let res = [| 2 ; 4 ; 6 ; 8 |] in
assert (b = res)


let () =
(* [num_domains] is the number of *new* domains spawned by the pool
performing computations in addition to the current domain. *)
let num_domains = Domain.recommended_domain_count () - 1 in
Printf.eprintf "Test parray on %d domains.\n" (num_domains + 1);
let pool = Task.setup_pool ~num_domains ~name:"pool" () in
Task.run pool begin fun () ->
test_map pool;
test_map_inplace pool;
test_mapi pool;
test_mapi_inplace pool;
end;
Task.teardown_pool pool;
prerr_endline "Success.";