Representation of computation graphs
This section of the manual documents the inner workings of the graph computation functions in the source file src/compute.jl.
The timing information reported here is obtained by running Documenter's @example blocks in a GitHub action. Because of this, there is no control over the hardware used, and consequently the timing values that appear in this document are not very reliable; this is especially true for the Parallel computations section.
Computation nodes
An expression like A*x+b contains
- three nodes A, b, and x that correspond to variables, and
- two computation nodes, one for the multiplication and the other for the addition.
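Purely as an illustration (hypothetical bookkeeping, not code from the package), one way to picture this graph is through parent lists, with nodes 1 to 3 holding the variables and nodes 4 and 5 holding the computations:
parents = [Int[], Int[], Int[],   # nodes 1:3 are the variables A, x, b (no parents)
           [1, 2],                # node 4 computes A*x
           [4, 3]]                # node 5 computes A*x+b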
Since we are aiming for allocation-free computation, we start by pre-allocating memory for all nodes
using Random
A = rand(Float64,400,30)   # pre-allocated storage for the variable A
x = rand(Float64,30)       # pre-allocated storage for the variable x
b = rand(Float64,400)      # pre-allocated storage for the variable b
Ax = similar(b)            # pre-allocated storage for the computation node A*x
Axb = similar(b)           # pre-allocated storage for the computation node A*x+b
and associate the following functions with the two computation nodes:
using LinearAlgebra
function node_Ax!(out::Vector{F},in1::Matrix{F}, in2::Vector{F}) where {F}
    mul!(out,in1,in2)
end
function node_Axb!(out::Vector{F},in1::Vector{F}, in2::Vector{F}) where {F}
    @. out = in1 + in2
end
It would be tempting to construct the computation graph out of such functions. However, every function in Julia has its own unique type (all subtypes of the Function abstract type). This is problematic because we will often need to iterate over the nodes of a graph, e.g., to re-evaluate all nodes in the graph or just the parents of a specific node. If all nodes have a unique type, then such iterations will not be type-stable.
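As a quick illustration of the issue (f1 and f2 below are throwaway functions, not part of the package):
f1() = 1
f2() = 2
typeof(f1) == typeof(f2)    # false: every function has its own unique type
eltype([f1, f2])            # Function, an abstract type, so iterating is not type-stable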
To resolve this issue we do two "semantic" transformations to the functions above: function closure and function wrapping with the package FunctionWrappers.
Function closure
Function closure allows us to obtain functions for all the nodes that "look the same" in the following sense:
- they all have the same signature (i.e., the same number of input parameters with the same types), and
- they all return a value of the same type.
Specifically, we "capture" the input parameters for the two computation nodes, which makes them look like parameter-free functions that return nothing:
@inline node_Ax_closed!() = let Ax=Ax, A=A, x=x
    node_Ax!(Ax,A,x)
    nothing
end
@inline node_Axb_closed!() = let Axb=Axb, Ax=Ax, b=b
    node_Axb!(Axb,Ax,b)
    nothing
end
See the Performance tips on the performance of captured variables for the reason behind the use of let, which essentially helps the compiler by "fixing" the captured variables. To be precise, it fixes the arrays, but not the values of their entries.
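For reference, a condensed version of the example from that performance tip (illustrative only, not code from this package) shows the pattern:
function abmult_boxed(r::Int)
    r < 0 && (r = -r)       # r is reassigned, so the closure's captured r may be boxed
    return x -> x * r
end
function abmult_fixed(r::Int)
    r < 0 && (r = -r)
    return let r = r        # "fix" r in a fresh local binding that the closure captures
        x -> x * r
    end
end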
Function wrapping
Even though all node functions now have similar inputs and outputs, they are still not of the same type (as far as Julia is concerned). To fix this issue, we use function wrapping (in the style of the package FunctionWrappers) to create a type-stable wrapper:
import ComputationGraphs
node_Ax_wrapped = ComputationGraphs.FunctionWrapper(node_Ax_closed!)
node_Axb_wrapped = ComputationGraphs.FunctionWrapper(node_Axb_closed!)
The "wrapped" functions can be called directly with:
begin # hide
node_Ax_wrapped()
node_Axb_wrapped()
nothing # hide
end # hide
or a little faster with
ComputationGraphs.do_ccall(node_Ax_wrapped)
ComputationGraphs.do_ccall(node_Axb_wrapped)
The code above does not actually use FunctionWrappers; instead it uses a very simplified version of FunctionWrappers that can only wrap functions with no arguments that always return nothing.
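The essential idea can be sketched with Base's @cfunction; this is only an illustration under the assumption that the wrapper stores a C-callable pointer, not the actual ComputationGraphs code. Once the closure has been turned into a function pointer, every call reduces to a ccall through a value with a fixed concrete type:
sketch_cfun = @cfunction($node_Ax_closed!, Cvoid, ())    # Base.CFunction; must be kept alive
sketch_ptr = Base.unsafe_convert(Ptr{Cvoid}, sketch_cfun)
ccall(sketch_ptr, Cvoid, ())                             # evaluates A*x through the pointer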
To use FunctionWrappers itself, we would instead have used
import FunctionWrappers
node_Ax_wrapped_FW = FunctionWrappers.FunctionWrapper{Nothing,Tuple{}}(node_Ax_closed!)
node_Axb_wrapped_FW = FunctionWrappers.FunctionWrapper{Nothing,Tuple{}}(node_Axb_closed!)
and the functions would be called with
begin # hide
FunctionWrappers.do_ccall(node_Ax_wrapped_FW, ())
FunctionWrappers.do_ccall(node_Axb_wrapped_FW, ())
nothing # hide
end # hide
Verification
We can now check the fruits of our work.
- Type stability?
println("Type stability for original: ", typeof(node_Ax!)==typeof(node_Axb!))
println("Type stability for wrapped : ", typeof(node_Ax_wrapped)==typeof(node_Axb_wrapped))
Type stability for original: false
Type stability for wrapped : true
- Correctness?
rand!(A)
rand!(b)
rand!(x)
# the original functions
node_Ax!(Ax,A,x)
node_Axb!(Axb,Ax,b)
println("Correctness for original: ", Axb==(A*x+b))
rand!(A)
rand!(b)
rand!(x)
# the new functions
ComputationGraphs.do_ccall(node_Ax_wrapped)
ComputationGraphs.do_ccall(node_Axb_wrapped)
println("Correctness for wrapped : ", Axb==(A*x+b))
Correctness for original: true
Correctness for wrapped : true
- Speed?
using BenchmarkTools, Printf
@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level
bmk1 = @benchmark begin
    node_Ax!($Ax,$A,$x)
    node_Axb!($Axb,$Ax,$b)
end evals=1000 samples=10000
bmk3 = @benchmark begin
    node_Ax_closed!()
    node_Axb_closed!()
end evals=1000 samples=10000
bmk2 = @benchmark begin
    ComputationGraphs.do_ccall($node_Ax_wrapped)
    ComputationGraphs.do_ccall($node_Axb_wrapped)
end evals=1000 samples=10000
@printf("Overhead due to closure = %3.f ns\n",median(bmk3.times)-median(bmk1.times))
@printf("Overhead due to wrapping = %3.f ns\n",median(bmk2.times)-median(bmk3.times))
@printf("Total overhead = %3.f ns\n",median(bmk2.times)-median(bmk1.times))
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
Original:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
Range (min … max): 1.162 μs … 2.461 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.195 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.196 μs ± 43.078 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▃▇█▆▃▁
▁▂▄▅▅▅▄▃▃▃▃▃▄▃▃▃▃▃▂▂▂▇██████▇▇▇█▆▆▄▅▄▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
1.16 μs Histogram: frequency by time 1.24 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Closure:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
Range (min … max): 1.197 μs … 2.144 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.223 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.219 μs ± 18.150 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▄▇▄▄▃▁ ▁▃▄▃▂▂▃▂▁▂▁▁▁▁ ▁▁ ▁█▇▅▅▃▂▂▂▆▅▃▃▃▃▂▁ ▁▁▁ ▂
██████████████████████████▇█████████████████▇█▇████▇▇▆▆▅▅▇ █
1.2 μs Histogram: log(frequency) by time 1.25 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Wrapped:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
Range (min … max): 1.207 μs … 2.489 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.237 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.238 μs ± 36.247 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▁▄█
▁▄▆▃▂▁▁▂▃▂▂▂▂▂▁▁▁▁▂▂▁▁▂████▆▄▃▄▄▇▅▄▄▃▂▁▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
1.21 μs Histogram: frequency by time 1.27 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Overhead due to closure = 28 ns
Overhead due to wrapping = 15 ns
Total overhead = 43 ns
This shows that closure and wrapping do introduce a small overhead (tens of ns). However, the benefits of type stability will appear when we start iterating over nodes. To see this, consider the following functions that evaluate a set of nodes:
function compute_all!(nodes::Vector{Function})
    for node in nodes
        node()
    end
end
function compute_all_wrapped!(nodes::Vector{ComputationGraphs.FunctionWrapper})
    for node::ComputationGraphs.FunctionWrapper in nodes
        ComputationGraphs.do_ccall(node)
    end
end
We can use @code_warntype to see how wrapping helps in terms of type stability:
# using just closure
nodes_closed=repeat([node_Ax_closed!,node_Axb_closed!],outer=5)
@show typeof(nodes_closed)
InteractiveUtils.@code_warntype compute_all!(nodes_closed)
# using closure+wrapped
nodes_wrapped=repeat([node_Ax_wrapped,node_Axb_wrapped],outer=5)
@show typeof(nodes_wrapped)
InteractiveUtils.@code_warntype compute_all_wrapped!(nodes_wrapped)
typeof(nodes_closed) = Vector{Function}
MethodInstance for Main.compute_all!(::Vector{Function})
from compute_all!(nodes::Vector{Function}) @ Main lib_representation.md:229
Arguments
#self#::Core.Const(Main.compute_all!)
nodes::Vector{Function}
Locals
@_3::Union{Nothing, Tuple{Function, Int64}}
node::Function
Body::Nothing
1 ─ %1 = nodes::Vector{Function}
│ (@_3 = Base.iterate(%1))
│ %3 = @_3::Union{Nothing, Tuple{Function, Int64}}
│ %4 = (%3 === nothing)::Bool
│ %5 = Base.not_int(%4)::Bool
└── goto #4 if not %5
2 ┄ %7 = @_3::Tuple{Function, Int64}
│ (node = Core.getfield(%7, 1))
│ %9 = Core.getfield(%7, 2)::Int64
│ %10 = node::Function
│ (%10)()
│ (@_3 = Base.iterate(%1, %9))
│ %13 = @_3::Union{Nothing, Tuple{Function, Int64}}
│ %14 = (%13 === nothing)::Bool
│ %15 = Base.not_int(%14)::Bool
└── goto #4 if not %15
3 ─ goto #2
4 ┄ return nothing
typeof(nodes_wrapped) = Vector{ComputationGraphs.FunctionWrapper}
MethodInstance for Main.compute_all_wrapped!(::Vector{ComputationGraphs.FunctionWrapper})
from compute_all_wrapped!(nodes::Vector{ComputationGraphs.FunctionWrapper}) @ Main lib_representation.md:234
Arguments
#self#::Core.Const(Main.compute_all_wrapped!)
nodes::Vector{ComputationGraphs.FunctionWrapper}
Locals
@_3::Union{Nothing, Tuple{ComputationGraphs.FunctionWrapper, Int64}}
node::ComputationGraphs.FunctionWrapper
@_5::ComputationGraphs.FunctionWrapper
Body::Nothing
1 ─ %1 = nodes::Vector{ComputationGraphs.FunctionWrapper}
│ (@_3 = Base.iterate(%1))
│ %3 = @_3::Union{Nothing, Tuple{ComputationGraphs.FunctionWrapper, Int64}}
│ %4 = (%3 === nothing)::Bool
│ %5 = Base.not_int(%4)::Bool
└── goto #7 if not %5
2 ┄ Core.NewvarNode(:(node))
│ %8 = @_3::Tuple{ComputationGraphs.FunctionWrapper, Int64}
│ %9 = Core.getfield(%8, 1)::ComputationGraphs.FunctionWrapper
│ %10 = ComputationGraphs.FunctionWrapper::Core.Const(ComputationGraphs.FunctionWrapper)
│ (@_5 = %9)
│ %12 = @_5::ComputationGraphs.FunctionWrapper
│ %13 = (%12 isa %10)::Core.Const(true)
└── goto #4 if not %13
3 ─ goto #5
4 ─ Core.Const(:(@_5))
│ Core.Const(:(Base.convert(%10, %16)))
└── Core.Const(:(@_5 = Core.typeassert(%17, %10)))
5 ┄ %19 = @_5::ComputationGraphs.FunctionWrapper
│ (node = %19)
│ %21 = Core.getfield(%8, 2)::Int64
│ %22 = ComputationGraphs.do_ccall::Core.Const(ComputationGraphs.do_ccall)
│ %23 = node::ComputationGraphs.FunctionWrapper
│ (%22)(%23)
│ (@_3 = Base.iterate(%1, %21))
│ %26 = @_3::Union{Nothing, Tuple{ComputationGraphs.FunctionWrapper, Int64}}
│ %27 = (%26 === nothing)::Bool
│ %28 = Base.not_int(%27)::Bool
└── goto #7 if not %28
6 ─ goto #2
7 ┄ return nothing
These specific functions compute_all! and compute_all_wrapped! are so simple that type instability actually does not lead to heap allocations, but the use of wrapped functions still leads to slightly faster code.
@show typeof(nodes_closed)
bmk3 = @benchmark compute_all!($nodes_closed) evals=1 samples=10000
@show typeof(nodes_wrapped)
bmk2 = @benchmark compute_all_wrapped!($nodes_wrapped) evals=1 samples=10000
typeof(nodes_closed) = Vector{Function}
Closure:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 6.231 μs … 29.465 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 6.321 μs ┊ GC (median): 0.00%
Time (mean ± σ): 6.431 μs ± 880.259 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▇█▃ ▂ ▁
███▅▅▄▃▁▃▄▅▆▃█▅▁▁▃▁▁▄▁▃▄▄▁▁▁▁▁▁▁▃▁▅▅▄▃▃▅▃▁▄▅▃▁▅▃▄▁▃▁▁▁▁▁▁▁▇ █
6.23 μs Histogram: log(frequency) by time 10.8 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
typeof(nodes_wrapped) = Vector{ComputationGraphs.FunctionWrapper}
Closure+Wrapping:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 5.951 μs … 33.312 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 6.152 μs ┊ GC (median): 0.00%
Time (mean ± σ): 6.270 μs ± 941.613 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█
▅▄██▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▂▂▁▁▁▁▁▁▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
5.95 μs Histogram: frequency by time 9.82 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Conditional computations
So far we discussed how to compute all nodes or some given vector of nodes. Restricting evaluations to just the set of nodes that need to be recomputed requires introducing some simple logic into the function closures.
Implementation
To support need-based evaluations, we use a BitVector to keep track of which nodes have been evaluated. For our 2-node example, we would use
validValue=falses(2)
The functions below now include the logic for need-based evaluation:
node_Ax_conditional_closed() = let validValue=validValue,
            Ax=Ax, A=A, x=x
    node_Ax!(Ax,A,x)      # this node's computation
    nothing
end
node_Ax_conditional_wrapped = ComputationGraphs.FunctionWrapper(node_Ax_conditional_closed)
node_Axb_conditional_closed() = let validValue=validValue,
            Axb=Axb, Ax=Ax, b=b,
            node_Ax_conditional_wrapped=node_Ax_conditional_wrapped
    # compute parent node Ax (if needed)
    if !validValue[1]
        validValue[1]=true
        ComputationGraphs.do_ccall(node_Ax_conditional_wrapped)
    end
    node_Axb!(Axb,Ax,b)   # this node's computation
    nothing
end
node_Axb_conditional_wrapped = ComputationGraphs.FunctionWrapper(node_Axb_conditional_closed)
With this logic, we only need a single call to evaluate the node A*x+b, as this will automatically trigger the evaluation of A*x (if needed). To check that the logic is working, we do:
begin # hide
fill!(validValue,false)
fill!(Ax,0.0)
fill!(Axb,0.0)
ComputationGraphs.do_ccall(node_Ax_conditional_wrapped)
@assert validValue == [false,false] "no parent computed"
@assert all(Ax .== A*x) "should only compute Ax"
@assert all(Axb .== 0) "should only compute Ax"
fill!(validValue,false)
fill!(Ax,0.0)
fill!(Axb,0.0)
ComputationGraphs.do_ccall(node_Axb_conditional_wrapped)
@assert validValue == [true,false] "parent should have been computed"
@assert all(Ax .== A*x) "should compute both"
@assert all(Axb .== A*x+b) "should compute both"
nothing # hide
end # hide
Timing verification
We can now check the impact of the new logic on timing.
using BenchmarkTools, Printf
@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level
bmk1 = @benchmark begin
    node_Ax!($Ax,$A,$x)
    node_Axb!($Axb,$Ax,$b)
end evals=1000 samples=10000
bmk2a = @benchmark begin
    ComputationGraphs.do_ccall($node_Ax_wrapped)
    ComputationGraphs.do_ccall($node_Axb_wrapped)
end evals=1000 samples=10000
bmk2b = @benchmark begin
    $validValue[1]=false
    $validValue[2]=false
    if !$validValue[2]
        $validValue[2]=true
        ComputationGraphs.do_ccall($node_Axb_conditional_wrapped)
    end
end evals=1000 samples=10000
bmk3 = @benchmark begin
    if !$validValue[2]
        $validValue[2]=true
        ComputationGraphs.do_ccall($node_Axb_conditional_wrapped)
    end
end evals=1 samples=10000
bmk4 = @benchmark begin
    $validValue[2]=false
    if !$validValue[2]
        $validValue[2]=true
        ComputationGraphs.do_ccall($node_Axb_conditional_wrapped)
    end
end evals=1000 samples=10000
@printf("overhead due to closure+wrapping for full computations = %+6.f ns\n",
median(bmk2a.times)-median(bmk1.times))
@printf("overhead due to closure+wrapping+logic for full computations = %+6.f ns\n",
median(bmk2b.times)-median(bmk1.times))
# @printf("overhead due just to logic for full computations = %+6.f ns\n", # hide
@printf("overhead due to closure+wrapping+logic for for computations = %+6.f ns (<0 means savings)\n",
median(bmk3.times)-median(bmk1.times))
@printf("overhead due to closure+wrapping+logic for partial computations = %+6.f ns (<0 means savings)\n",
median(bmk4.times)-median(bmk1.times))
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
Unconditional computation:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
Range (min … max): 1.152 μs … 1.846 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.188 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.188 μs ± 17.270 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█▇▇▅▄▃▄▁
▂▃▃▃▃▄▃▂▁▁▂▂▃▂▂▂▁▁▂▂▁▁▁▂▄▅████████▅▅▄▅▇▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
1.15 μs Histogram: frequency by time 1.22 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Unconditional computation with wrapping:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 5.951 μs … 33.312 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 6.152 μs ┊ GC (median): 0.00%
Time (mean ± σ): 6.270 μs ± 941.613 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█
▅▄██▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▂▂▁▁▁▁▁▁▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
5.95 μs Histogram: frequency by time 9.82 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Conditional computation, but with all valid=false:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 5.951 μs … 33.312 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 6.152 μs ┊ GC (median): 0.00%
Time (mean ± σ): 6.270 μs ± 941.613 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█
▅▄██▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▂▂▁▁▁▁▁▁▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
5.95 μs Histogram: frequency by time 9.82 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Conditional computation, with full reuse:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 29.000 ns … 9.678 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 30.000 ns ┊ GC (median): 0.00%
Time (mean ± σ): 32.357 ns ± 96.627 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█
▃▁▁▁█▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▂ ▂
29 ns Histogram: frequency by time 41 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
Conditional computation, with valid=false only for Axb:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
Range (min … max): 98.883 ns … 181.697 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 99.705 ns ┊ GC (median): 0.00%
Time (mean ± σ): 100.838 ns ± 3.906 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▂▆██▆▃▂ ▂▄▃▂▁▁ ▂
████████▆▇█▇▇▇▅▅▆▁▃▃▁▁▃▃▄▁▄███████▇▇▆▇▆▆▄▅▄▁▄▃▃▅▅▆▆▇▆▅▆▆▆▄▅▆▄ █
98.9 ns Histogram: log(frequency) by time 115 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
overhead due to closure+wrapping for full computations = +41 ns
overhead due to closure+wrapping+logic for full computations = +142 ns
overhead due to closure+wrapping+logic for reused computations = -1158 ns (<0 means savings)
overhead due to closure+wrapping+logic for partial computations = -1089 ns (<0 means savings)
As expected, much time is saved when re-evaluations are not needed. When they are needed, the logic adds a small additional penalty.
The code above is the basis for ComputationGraphs.generateComputeFunctions.
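As a rough idea of how this pattern generalizes, the hypothetical helper below (a sketch only, not the actual generateComputeFunctions) builds such a conditional closure for one node from the indices of its parents; the closures would be created, and the wrapped vector filled, with parents before children, mirroring the two-node example above.
# hypothetical helper, for illustration only
function make_conditional(parents::Vector{Int}, compute!::Function,
        wrapped::Vector{ComputationGraphs.FunctionWrapper}, validValue::BitVector)
    closed = function ()
        for p in parents                           # make sure every parent is valid
            if !validValue[p]
                validValue[p] = true
                ComputationGraphs.do_ccall(wrapped[p])
            end
        end
        compute!()                                 # this node's computation
        nothing
    end
    return ComputationGraphs.FunctionWrapper(closed)
end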
Parallel computations
Parallel evaluations are implemented by associating to each computation node one Threads.Task and one pair of Threads.Event objects. For each computation node i:
- The task task[i]::Threads.Task is responsible for carrying out the evaluation of node i and for synchronizing it with the other nodes.
- The event request[i]::Threads.Event(autoreset=true) is used to request task[i] to evaluate its node, by issuing notify(request[i]).
- The event valid[i]::Threads.Event(autoreset=false) is used by node i to notify all other nodes that it has finished handling a computation request received through request[i].
The following protocol is used (see the sketch after this list):
1. All node tasks are spawned simultaneously and each task i immediately waits on request[i] for an evaluation request.
2. Upon receiving a request, task i checks which of its parents have valid data:
   - For every parent p with missing data, it issues an evaluation request using notify(request[p]).
   - After that, the task waits for these requests to be fulfilled by using wait(valid[p]) for the same set of parent nodes.
3. Once all parents have valid data, node i performs its own computation and notifies any waiting child node that its data became valid using notify(valid[i]).
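A hedged, generic sketch of this protocol follows; the helper spawn_node_task is hypothetical (it is not the package's computeSpawn!), with parents listing the parent indices of node i and compute! performing its computation. The concrete 3-node version appears in the Parallelism implementation section below.
function spawn_node_task(i::Int, parents::Vector{Int}, compute!::Function, request, valid)
    Threads.@spawn while true
        wait(request[i])                         # wait for an evaluation request
        if !valid[i].set
            for p in parents                     # request any parent with missing data
                valid[p].set || notify(request[p])
            end
            for p in parents                     # wait until those parents are valid
                valid[p].set || wait(valid[p])
            end
            compute!()                           # this node's computation
            notify(valid[i])                     # tell waiting children the data is valid
        end
    end
end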
The operation described above makes the following assumptions:
1. Any thread that needs the value of node i should first issue an evaluation request using notify(request[i]) and then wait for its completion using wait(valid[i]).
2. When the value of a variable v changes, all its children nodes c need to be notified that their values have become invalid by issuing reset(valid[c]).
3. To avoid races, these last reset(valid[c]) calls cannot be done while computations are being performed.
The last assumption above should be enforced by an explicit locking mechanism, but that has not yet been implemented.
For very large matrix multiplications, BLAS makes good use of multiple threads. In this case, we should not expect significant improvements with respect to evaluating the computation graph sequentially. Instead, it is better to allow BLAS to manage all the threads, with a sequential evaluation of the computation graph.
Parallelism implementation
We will illustrate the mechanism above with the computation of A*x+B*y, for which the two multiplications can be parallelized. The corresponding graph has
- four nodes A, x, B, y that correspond to variables; and
- three computation nodes, one for each of the two multiplications and one for the addition.
We start by pre-allocating memory for all nodes
using Random
A = rand(Float64,4000,3000)
x = rand(Float64,3000,1000)
B = rand(Float64,4000,2500)
y = rand(Float64,2500,1000)
Ax = Matrix{Float64}(undef,4000,1000)
By = similar(Ax)
AxBy = similar(Ax)
and defining the computations for each node
using LinearAlgebra
function node_Ax!(out::Matrix{F},in1::Matrix{F}, in2::Matrix{F}) where {F}
    mul!(out,in1,in2)
end
function node_By!(out::Matrix{F},in1::Matrix{F}, in2::Matrix{F}) where {F}
    mul!(out,in1,in2)
end
function node_AxBy!(out::Matrix{F},in1::Matrix{F}, in2::Matrix{F}) where {F}
    @. out = in1 + in2
end
We use fairly large matrices, for which the computation takes some time, as we can see below:
@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level
@time begin
    node_Ax!(Ax,A,x)
    node_By!(By,B,y)
    node_AxBy!(AxBy,Ax,By)
end
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
0.932706 seconds
To implement the parallelization mechanism described above, we need two event-triggered objects per node:
valid=Tuple(Threads.Event(false) for _ in 1:3)
request=Tuple(Threads.Event(true) for _ in 1:3)
reset.(valid)
reset.(request)
The computation tasks can then be launched using
tasks=[
    Threads.@spawn while true
        wait(request[1])
        if !valid[1].set
            node_Ax!(Ax,A,x)   # this node's computation
            notify(valid[1])
        end
    end
    Threads.@spawn while true
        wait(request[2])
        if !valid[2].set
            node_By!(By,B,y)   # this node's computation
            notify(valid[2])
        end
    end
    Threads.@spawn while true
        wait(request[3])
        if !valid[3].set
            valid[1].set || notify(request[1])
            valid[2].set || notify(request[2])
            valid[1].set || wait(valid[1])
            valid[2].set || wait(valid[2])
            node_AxBy!(AxBy,Ax,By)   # this node's computation
            notify(valid[3])
        end
    end
]
3-element Vector{Task}:
Task (runnable, started) @0x00007f048398f3a0
Task (runnable, started) @0x00007f048398f530
Task (runnable, started) @0x00007f048398f6c0
Very similar code is used in computeSpawn! to parallelize the computation of general graphs.
Parallelism verification
To verify the operation of the approach outlined above, we make a request for the value of the final node A*x+B*y and then wait for that node to become valid:
using ThreadPinning
pinthreads(:cores)
@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level
fill!(Ax,0.0)
fill!(By,0.0)
fill!(AxBy,0.0)
reset.(valid)
println("valid before :",getproperty.(valid,:set))
@time begin
    notify(request[3])
    wait(valid[3])
end
println("valid after :",getproperty.(valid,:set))
@assert Ax==A*x
@assert By==B*y
@assert AxBy==A*x+B*y
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
valid before :(false, false, false)
0.511647 seconds (9 allocations: 416 bytes)
valid after :(true, true, true)
When multiple hardware threads are available, the time reported by @time is roughly half of the sequential time above, showing good use of the threads.
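As a side note, if only one of the variables subsequently changes, only the nodes that depend on it need to be invalidated before the next request (assumption 2 above); the sketch below (not executed here, and assuming the tasks above are still running) would reuse B*y while recomputing A*x and the sum:
rand!(x)              # new value for the variable x
reset(valid[1])       # A*x is no longer valid
reset(valid[3])       # A*x+B*y is no longer valid
notify(request[3])    # request the final node again
wait(valid[3])        # only A*x and the addition are recomputed; B*y is reused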
We can see whether the Julia threads were successfully "pinned" to physical hardware threads using ThreadPinning.threadinfo(), where red means that multiple Julia threads are running on the same hardware thread and purple means that the Julia thread is running on a hyperthread. In either case, we should not expect true parallelism. This is often the case when the code is run through a GitHub action (as when generating this manual page) on a machine with a small number of cores and Simultaneous Multithreading (SMT).
using ThreadPinning
@show Threads.nthreads()
@show pinthreads(:cores)
threadinfo()
Threads.nthreads() = 2
pinthreads(:cores) = nothing
Hostname: runnervmf4ws1
CPU(s): 1 x AMD EPYC 7763 64-Core Processor
CPU target: znver3
Cores: 2 (4 CPU-threads due to 2-way SMT)
NUMA domains: 1 (2 cores each)
Julia threads: 2
CPU socket 1
0,1, 2,3
# = Julia thread, # = Julia thread on HT, # = >1 Julia thread
(Mapping: 1 => 0, 2 => 2,)