Representation of computation graphs

This section of the manual documents the inner workings of the graph computation functions in the source file src/compute.jl.

Warning

The timing information reported here is obtained by running Documenter's @example blocks, triggered by a GitHub Action. Because of this, there is no control over the hardware used, and consequently the timing values that appear in this document are not very reliable, especially for the parallel computations.

Computation nodes

An expression like A*x+b contains

  • three nodes A, b, and x that correspond to variables, and
  • two computation nodes, one for the multiplication and the other for the addition.

Since we are aiming for allocation-free computation, we start by pre-allocating memory for all nodes

using Random
A = rand(Float64,400,30)  # pre-allocated storage for the variable A
x = rand(Float64,30)     # pre-allocated storage for the variable x
b = rand(Float64,400)    # pre-allocated storage for the variable b
Ax = similar(b)          # pre-allocated storage for the computation node A*x
Axb = similar(b)         # pre-allocated storage for the computation node A*x+b

and associate the following functions with the two computation nodes:

using LinearAlgebra
function node_Ax!(out::Vector{F},in1::Matrix{F}, in2::Vector{F}) where {F}
    mul!(out,in1,in2)
end
function node_Axb!(out::Vector{F},in1::Vector{F}, in2::Vector{F}) where {F}
    @. out = in1 + in2
end
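As a quick sanity check of the allocation-free claim, the two node functions can be measured with Base's @allocated. The block below repeats the definitions above so that it runs on its own; measure_allocations is a hypothetical helper, and the exact number of bytes reported for the allocating expression A*x+b depends on the Julia version.

```julia
using LinearAlgebra

# Same in-place node functions as above.
node_Ax!(out::Vector{F}, in1::Matrix{F}, in2::Vector{F}) where {F} = mul!(out, in1, in2)
node_Axb!(out::Vector{F}, in1::Vector{F}, in2::Vector{F}) where {F} = (@. out = in1 + in2)

# Hypothetical helper: bytes allocated by one call of each node function and by
# the allocating expression A*x+b. Measuring inside a function avoids counting
# allocations caused by untyped global variables.
function measure_allocations(A, x, b, Ax, Axb)
    node_Ax!(Ax, A, x); node_Axb!(Axb, Ax, b); A * x + b   # warm-up (compilation)
    a1 = @allocated node_Ax!(Ax, A, x)
    a2 = @allocated node_Axb!(Axb, Ax, b)
    a3 = @allocated A * x + b                              # creates new vectors
    return a1, a2, a3
end

a1, a2, a3 = measure_allocations(rand(400, 30), rand(30), rand(400), zeros(400), zeros(400))
@assert a1 == 0 && a2 == 0   # the in-place node functions do not allocate
@assert a3 > 0               # the out-of-place expression does
```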

It would be tempting to construct the computation graph directly out of such functions. However, every function in Julia has its own unique type (all of them subtypes of the abstract type Function). This is problematic because we often need to iterate over the nodes of a graph, e.g., to re-evaluate all nodes in the graph or just the parents of a specific node. If each node has a different type, such iterations are not type-stable.
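A minimal sketch of the problem, using two hypothetical functions f! and g! with identical signatures: each function has its own type, so a vector holding both can only have the abstract element type Function.

```julia
# Two in-place functions with identical signatures (hypothetical names).
f!(v) = (v .+= 1; nothing)
g!(v) = (v .*= 2; nothing)

@assert typeof(f!) != typeof(g!)      # each function has its own singleton type

fs = [f!, g!]                         # element type is the join of both types
@assert eltype(fs) === Function       # an abstract type ...
@assert !isconcretetype(eltype(fs))   # ... so the compiler cannot infer which method a call hits
```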

To resolve this issue, we apply two "semantic" transformations to the functions above: function closure and function wrapping with the package FunctionWrappers.

Function closure

Function closures allow us to obtain functions for all the nodes that "look the same" in the following sense:

  • they all have the same signature (i.e., the same number of input parameters, with the same types), and
  • they all return a value of the same type.

Specifically, we "capture" the input parameters for the two computation nodes, which makes them look like parameter-free functions that return nothing:

@inline node_Ax_closed!() = let Ax=Ax, A=A, x=x
    node_Ax!(Ax,A,x)
    nothing
end
@inline node_Axb_closed!() = let Axb=Axb, Ax=Ax, b=b
    node_Axb!(Axb,Ax,b)
    nothing
end
Note

See the section "Performance of captured variables" in the Julia Performance Tips on the use of let, which essentially helps the compiler by "fixing" the captured variables. To be precise, it fixes which arrays are captured, but not the values of their entries.
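The effect described in the note can be seen in a small sketch (buf, src and copy_closed! are hypothetical names). Here the let wraps the construction of an anonymous function, which is the form shown in the Performance Tips: the closure keeps referring to the same array objects, so changes to their entries are visible, while rebinding a global name afterwards is not.

```julia
buf = zeros(3)
src = [1.0, 2.0, 3.0]

# `let` introduces new local bindings `buf` and `src`; the anonymous function
# captures these bindings, which never change afterwards.
copy_closed! = let buf = buf, src = src
    () -> (copyto!(buf, src); nothing)
end

copy_closed!()
@assert buf == [1.0, 2.0, 3.0]

src .= 10.0              # mutating the entries of the captured array is visible
copy_closed!()
@assert buf == [10.0, 10.0, 10.0]

src = [0.0, 0.0, 0.0]    # rebinding the global name does not affect the closure
copy_closed!()
@assert buf == [10.0, 10.0, 10.0]
```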

Function wrapping

Even though all node functions now have the same inputs and outputs, they are still not of the same type (as far as Julia is concerned). To fix this issue, we use the package FunctionWrappers to create a type-stable wrapper:

import ComputationGraphs
node_Ax_wrapped = ComputationGraphs.FunctionWrapper(node_Ax_closed!)
node_Axb_wrapped = ComputationGraphs.FunctionWrapper(node_Axb_closed!)

The "wrapped" functions can be called directly with:

node_Ax_wrapped()
node_Axb_wrapped()

or a little faster with

ComputationGraphs.do_ccall(node_Ax_wrapped)
ComputationGraphs.do_ccall(node_Axb_wrapped)
Warning

The code above does not actually use FunctionWrappers; instead it uses a very simplified version of FunctionWrappers that can only wrap functions with no arguments that always return nothing.

To use FunctionWrappers itself, we would instead have written

import FunctionWrappers
node_Ax_wrapped_FW = FunctionWrappers.FunctionWrapper{Nothing,Tuple{}}(node_Ax_closed!)
node_Axb_wrapped_FW = FunctionWrappers.FunctionWrapper{Nothing,Tuple{}}(node_Axb_closed!)

and the functions would be called with

FunctionWrappers.do_ccall(node_Ax_wrapped_FW, ())
FunctionWrappers.do_ccall(node_Axb_wrapped_FW, ())
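To give an idea of how a wrapper restricted to zero-argument functions returning nothing can be type-stable, here is a sketch based on a C function pointer obtained with @cfunction and invoked with ccall. VoidWrapper and call_wrapped are hypothetical names; this is not the code of ComputationGraphs.FunctionWrapper, only an illustration of the general idea, and it assumes the wrapped functions are ordinary named functions (the closure form @cfunction($f, ...) is not supported on all platforms).

```julia
# Hypothetical minimal wrapper: a C-callable pointer to a zero-argument
# function that returns `nothing`.
struct VoidWrapper
    ptr::Ptr{Cvoid}
end

# Every wrapper has the same concrete type, whatever function it points to.
call_wrapped(w::VoidWrapper) = ccall(w.ptr, Cvoid, ())

const counter = Ref(0)
inc1() = (counter[] += 1; nothing)
inc10() = (counter[] += 10; nothing)

w1 = VoidWrapper(@cfunction(inc1, Cvoid, ()))
w10 = VoidWrapper(@cfunction(inc10, Cvoid, ()))

ws = [w1, w10]                      # Vector{VoidWrapper}: concrete element type
@assert isconcretetype(eltype(ws))
for w in ws
    call_wrapped(w)                 # one call site, two different functions
end
@assert counter[] == 11
```

Because every wrapper has the same concrete type, the loop over ws is type-stable even though it calls different functions. Such pointers are only valid within the running Julia session.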

Verification

We can now check the fruits of our work.

  • Type stability?
println("Type stability for original: ", typeof(node_Ax!)==typeof(node_Axb!))
println("Type stability for wrapped : ", typeof(node_Ax_wrapped)==typeof(node_Axb_wrapped))
Type stability for original: false
Type stability for wrapped : true
  • Correctness?
rand!(A)
rand!(b)
rand!(x)

# the original functions
node_Ax!(Ax,A,x)
node_Axb!(Axb,Ax,b)
println("Correctness for original: ", Axb==(A*x+b))

rand!(A)
rand!(b)
rand!(x)

# the new functions
ComputationGraphs.do_ccall(node_Ax_wrapped)
ComputationGraphs.do_ccall(node_Axb_wrapped)
println("Correctness for wrapped : ", Axb==(A*x+b))
Correctness for original: true
Correctness for wrapped : true
  • Speed?
using BenchmarkTools, Printf
@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level

bmk1 = @benchmark begin
    node_Ax!($Ax,$A,$x)
    node_Axb!($Axb,$Ax,$b)
end evals=1000 samples=10000

bmk3 = @benchmark begin
    node_Ax_closed!()
    node_Axb_closed!()
end evals=1000 samples=10000

bmk2 = @benchmark begin
    ComputationGraphs.do_ccall($node_Ax_wrapped)
    ComputationGraphs.do_ccall($node_Axb_wrapped)
end evals=1000 samples=10000

@printf("Overhead due to closure  = %3.f ns\n",median(bmk3.times)-median(bmk1.times))
@printf("Overhead due to wrapping = %3.f ns\n",median(bmk2.times)-median(bmk3.times))
@printf("Total overhead           = %3.f ns\n",median(bmk2.times)-median(bmk1.times))
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
Original:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
 Range (min … max):  1.162 μs … 2.461 μs   GC (min … max): 0.00% … 0.00%
 Time  (median):     1.195 μs               GC (median):    0.00%
 Time  (mean ± σ):   1.196 μs ± 43.078 ns   GC (mean ± σ):  0.00% ± 0.00%

                        ▃▇█▁                               
  ▁▂▄▅▅▅▄▃▃▃▃▃▄▃▃▃▃▃▂▂▂▇████▇▇▇█▆▆▄▅▄▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  1.16 μs        Histogram: frequency by time        1.24 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
Closure:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
 Range (min … max):  1.197 μs … 2.144 μs   GC (min … max): 0.00% … 0.00%
 Time  (median):     1.223 μs               GC (median):    0.00%
 Time  (mean ± σ):   1.219 μs ± 18.150 ns   GC (mean ± σ):  0.00% ± 0.00%

  ▄▇▄▄▃▁ ▁▃▄▃▂▂▃▂▁▂▁▁▁▁ ▁▁   ▁█▅▅▃▂▂▂▆▅▃▃▃▃▂▁    ▁▁▁        ▂
  █████████████████████████▇████████████████▇█▇████▇▇▆▆▅▅▇ █
  1.2 μs       Histogram: log(frequency) by time     1.25 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
Wrapped:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
 Range (min … max):  1.207 μs … 2.489 μs   GC (min … max): 0.00% … 0.00%
 Time  (median):     1.237 μs               GC (median):    0.00%
 Time  (mean ± σ):   1.238 μs ± 36.247 ns   GC (mean ± σ):  0.00% ± 0.00%

                          ▁▄                                
  ▁▄▆▃▂▁▁▂▃▂▂▂▂▂▁▁▁▁▂▂▁▁▂███▄▃▄▄▇▅▄▄▃▂▁▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  1.21 μs        Histogram: frequency by time        1.27 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
Overhead due to closure  =  28 ns
Overhead due to wrapping =  15 ns
Total overhead           =  43 ns

This shows that closure and wrapping introduce a small overhead (tens of ns). However, the benefits of type stability appear when we start iterating over nodes. To see this, consider the following functions that evaluate a set of nodes:

function compute_all!(nodes::Vector{Function})
    for node in nodes
        node()
    end
end
function compute_all_wrapped!(nodes::Vector{ComputationGraphs.FunctionWrapper})
    for node::ComputationGraphs.FunctionWrapper in nodes
        ComputationGraphs.do_ccall(node)
    end
end

We can use @code_warntype to see how wrapping helps in terms of type stability:

# using just closure
nodes_closed=repeat([node_Ax_closed!,node_Axb_closed!],outer=5)
@show typeof(nodes_closed)
InteractiveUtils.@code_warntype compute_all!(nodes_closed)

# using closure+wrapped
nodes_wrapped=repeat([node_Ax_wrapped,node_Axb_wrapped],outer=5)
@show typeof(nodes_wrapped)
InteractiveUtils.@code_warntype compute_all_wrapped!(nodes_wrapped)
typeof(nodes_closed) = Vector{Function}
MethodInstance for Main.compute_all!(::Vector{Function})
  from compute_all!(nodes::Vector{Function}) @ Main lib_representation.md:229
Arguments
  #self#::Core.Const(Main.compute_all!)
  nodes::Vector{Function}
Locals
  @_3::Union{Nothing, Tuple{Function, Int64}}
  node::Function
Body::Nothing
1 ─ %1  = nodes::Vector{Function}
       (@_3 = Base.iterate(%1))
 %3  = @_3::Union{Nothing, Tuple{Function, Int64}}
 %4  = (%3 === nothing)::Bool
 %5  = Base.not_int(%4)::Bool
└──       goto #4 if not %5
2 ┄ %7  = @_3::Tuple{Function, Int64}
       (node = Core.getfield(%7, 1))
 %9  = Core.getfield(%7, 2)::Int64
 %10 = node::Function
       (%10)()
       (@_3 = Base.iterate(%1, %9))
 %13 = @_3::Union{Nothing, Tuple{Function, Int64}}
 %14 = (%13 === nothing)::Bool
 %15 = Base.not_int(%14)::Bool
└──       goto #4 if not %15
3 ─       goto #2
4 ┄       return nothing

typeof(nodes_wrapped) = Vector{ComputationGraphs.FunctionWrapper}
MethodInstance for Main.compute_all_wrapped!(::Vector{ComputationGraphs.FunctionWrapper})
  from compute_all_wrapped!(nodes::Vector{ComputationGraphs.FunctionWrapper}) @ Main lib_representation.md:234
Arguments
  #self#::Core.Const(Main.compute_all_wrapped!)
  nodes::Vector{ComputationGraphs.FunctionWrapper}
Locals
  @_3::Union{Nothing, Tuple{ComputationGraphs.FunctionWrapper, Int64}}
  node::ComputationGraphs.FunctionWrapper
  @_5::ComputationGraphs.FunctionWrapper
Body::Nothing
1 ─ %1  = nodes::Vector{ComputationGraphs.FunctionWrapper}
       (@_3 = Base.iterate(%1))
 %3  = @_3::Union{Nothing, Tuple{ComputationGraphs.FunctionWrapper, Int64}}
 %4  = (%3 === nothing)::Bool
 %5  = Base.not_int(%4)::Bool
└──       goto #7 if not %5
2 ┄       Core.NewvarNode(:(node))
 %8  = @_3::Tuple{ComputationGraphs.FunctionWrapper, Int64}
 %9  = Core.getfield(%8, 1)::ComputationGraphs.FunctionWrapper
 %10 = ComputationGraphs.FunctionWrapper::Core.Const(ComputationGraphs.FunctionWrapper)
       (@_5 = %9)
 %12 = @_5::ComputationGraphs.FunctionWrapper
 %13 = (%12 isa %10)::Core.Const(true)
└──       goto #4 if not %13
3 ─       goto #5
4 ─       Core.Const(:(@_5))
       Core.Const(:(Base.convert(%10, %16)))
└──       Core.Const(:(@_5 = Core.typeassert(%17, %10)))
5 ┄ %19 = @_5::ComputationGraphs.FunctionWrapper
       (node = %19)
 %21 = Core.getfield(%8, 2)::Int64
 %22 = ComputationGraphs.do_ccall::Core.Const(ComputationGraphs.do_ccall)
 %23 = node::ComputationGraphs.FunctionWrapper
       (%22)(%23)
       (@_3 = Base.iterate(%1, %21))
 %26 = @_3::Union{Nothing, Tuple{ComputationGraphs.FunctionWrapper, Int64}}
 %27 = (%26 === nothing)::Bool
 %28 = Base.not_int(%27)::Bool
└──       goto #7 if not %28
6 ─       goto #2
7 ┄       return nothing

The functions compute_all! and compute_all_wrapped! are so simple that type instability does not actually lead to heap allocations, but the wrapped functions still lead to slightly faster code.

@show typeof(nodes_closed)
bmk3 = @benchmark compute_all!($nodes_closed) evals=1 samples=10000

@show typeof(nodes_wrapped)
bmk2 = @benchmark compute_all_wrapped!($nodes_wrapped)  evals=1 samples=10000
typeof(nodes_closed) = Vector{Function}
Closure:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  6.231 μs … 29.465 μs   GC (min … max): 0.00% … 0.00%
 Time  (median):     6.321 μs                GC (median):    0.00%
 Time  (mean ± σ):   6.431 μs ± 880.259 ns   GC (mean ± σ):  0.00% ± 0.00%

  ▇          ▂                                              ▁
  █▅▄▃▁▃▄▅▆▃█▅▁▁▃▁▁▄▁▃▄▄▁▁▁▁▁▁▁▃▁▅▅▄▃▃▅▃▁▄▅▃▁▅▃▄▁▃▁▁▁▁▁▁▁▇ █
  6.23 μs      Histogram: log(frequency) by time      10.8 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
typeof(nodes_wrapped) = Vector{ComputationGraphs.FunctionWrapper}
Closure+Wrapping:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  5.951 μs … 33.312 μs   GC (min … max): 0.00% … 0.00%
 Time  (median):     6.152 μs                GC (median):    0.00%
 Time  (mean ± σ):   6.270 μs ± 941.613 ns   GC (mean ± σ):  0.00% ± 0.00%

                                                             
  ▅▄█▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▂▂▁▁▁▁▁▁▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
  5.95 μs         Histogram: frequency by time        9.82 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

Conditional computations

So far we have discussed how to compute all nodes, or a given vector of nodes. Restricting evaluations to just the set of nodes that need to be recomputed requires adding some simple logic to the function closures.

Implementation

To support need-based evaluations, we use a BitVector to keep track of which nodes have been evaluated. For our 2-node example, we would use

validValue=falses(2)

The functions below now include the logic for need-based evaluation:

node_Ax_conditional_closed() = let validValue=validValue,
    Ax=Ax , A=A, x=x
    node_Ax!(Ax,A,x)    # this node's computation
    nothing
    end
node_Ax_conditional_wrapped = ComputationGraphs.FunctionWrapper(node_Ax_conditional_closed)
node_Axb_conditional_closed() = let validValue=validValue,
    Axb=Axb, Ax=Ax, b=b,
    node_Ax_conditional_wrapped=node_Ax_conditional_wrapped
    # compute parent node Ax (if needed)
    if !validValue[1]
        validValue[1]=true
        ComputationGraphs.do_ccall(node_Ax_conditional_wrapped)
    end
    node_Axb!(Axb,Ax,b)  # this node's computation
    nothing
    end
node_Axb_conditional_wrapped = ComputationGraphs.FunctionWrapper(node_Axb_conditional_closed)

With this logic, we only need to call the function of the node A*x+b, as this automatically triggers the evaluation of A*x (if needed). To check that the logic is working, we do:

fill!(validValue,false)
fill!(Ax,0.0)
fill!(Axb,0.0)
ComputationGraphs.do_ccall(node_Ax_conditional_wrapped)
@assert validValue == [false,false] "no parent computed"
@assert all(Ax .== A*x)  "should only compute Ax"
@assert all(Axb .== 0) "should only compute Ax"

fill!(validValue,false)
fill!(Ax,0.0)
fill!(Axb,0.0)
ComputationGraphs.do_ccall(node_Axb_conditional_wrapped)
@assert validValue == [true,false] "parent should have been computed"
@assert all(Ax .== A*x)  "should compute both"
@assert all(Axb .== A*x+b) "should compute both"
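The same idea extends to graphs with more than two nodes. The sketch below stores, for each computation node, the list of its parents and recursively validates them before computing the node. TinyGraph, evaluate! and the variables M, v, c are hypothetical; for brevity the node functions are kept in a Vector{Function}, which is exactly the type-unstable representation that closure plus wrapping is meant to avoid.

```julia
using LinearAlgebra

# Hypothetical sketch, not the ComputationGraphs API: need-based evaluation of
# a graph whose computation nodes are numbered 1..n.
struct TinyGraph
    parents::Vector{Vector{Int}}   # parents[i]: computation nodes read by node i
    compute::Vector{Function}      # compute[i](): updates node i's pre-allocated output
    valid::BitVector               # valid[i]: node i holds an up-to-date value
end

# Evaluate node i only if it is not valid, after recursively validating its parents.
function evaluate!(g::TinyGraph, i::Int)
    g.valid[i] && return nothing
    for p in g.parents[i]
        evaluate!(g, p)
    end
    g.compute[i]()
    g.valid[i] = true
    return nothing
end

M = rand(4, 3); v = rand(3); c = rand(4)   # variables
Mv = zeros(4); Mvc = zeros(4)              # outputs of the computation nodes
ncalls = zeros(Int, 2)                     # how many times each node was computed

g = TinyGraph([Int[], [1]],                # node 1 = M*v, node 2 = M*v + c (reads node 1)
    Function[() -> (ncalls[1] += 1; mul!(Mv, M, v); nothing),
             () -> (ncalls[2] += 1; Mvc .= Mv .+ c; nothing)],
    falses(2))

evaluate!(g, 2)          # computes node 1, then node 2
evaluate!(g, 2)          # both nodes valid: nothing is recomputed
@assert ncalls == [1, 1]
@assert Mvc ≈ M * v + c

c .= 1.0                 # c changed: only node 2 depends on it
g.valid[2] = false
evaluate!(g, 2)
@assert ncalls == [1, 2]
@assert Mvc ≈ M * v .+ 1.0
```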

Timing verification

We can now check the impact of the new logic on timing.

using BenchmarkTools, Printf
@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level

bmk1 = @benchmark begin
    node_Ax!($Ax,$A,$x)
    node_Axb!($Axb,$Ax,$b)
end evals=1000 samples=10000

bmk2a = @benchmark begin
    ComputationGraphs.do_ccall($node_Ax_wrapped)
    ComputationGraphs.do_ccall($node_Axb_wrapped)
end evals=1000 samples=10000

bmk2b = @benchmark begin
    $validValue[1]=false
    $validValue[2]=false
    if !$validValue[2]
        $validValue[2]=true
        ComputationGraphs.do_ccall($node_Axb_conditional_wrapped)
    end
end evals=1000 samples=10000

bmk3 = @benchmark begin
    if !$validValue[2]
        $validValue[2]=true
        ComputationGraphs.do_ccall($node_Axb_conditional_wrapped)
    end
end evals=1 samples=10000

bmk4 = @benchmark begin
    $validValue[2]=false
    if !$validValue[2]
        $validValue[2]=true
        ComputationGraphs.do_ccall($node_Axb_conditional_wrapped)
    end
end evals=1000 samples=10000

@printf("overhead due to closure+wrapping for full computations          = %+6.f ns\n",
    median(bmk2a.times)-median(bmk1.times))
@printf("overhead due to closure+wrapping+logic for full computations    = %+6.f ns\n",
    median(bmk2b.times)-median(bmk1.times))
@printf("overhead due to closure+wrapping+logic for for computations     = %+6.f ns (<0 means savings)\n",
    median(bmk3.times)-median(bmk1.times))
@printf("overhead due to closure+wrapping+logic for partial computations = %+6.f ns (<0 means savings)\n",
    median(bmk4.times)-median(bmk1.times))
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
Unconditional computation:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
 Range (min … max):  1.152 μs … 1.846 μs   GC (min … max): 0.00% … 0.00%
 Time  (median):     1.188 μs               GC (median):    0.00%
 Time  (mean ± σ):   1.188 μs ± 17.270 ns   GC (mean ± σ):  0.00% ± 0.00%

                            █▇▇▄▃▄▁                         
  ▂▃▃▃▃▄▃▂▁▁▂▂▃▂▂▂▁▁▂▂▁▁▁▂▄▅███████▅▅▄▅▇▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  1.15 μs        Histogram: frequency by time        1.22 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
Conditional computation, with full reuse:
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  29.000 ns … 9.678 μs   GC (min … max): 0.00% … 0.00%
 Time  (median):     30.000 ns               GC (median):    0.00%
 Time  (mean ± σ):   32.357 ns ± 96.627 ns   GC (mean ± σ):  0.00% ± 0.00%

      █                                                       
  ▃▁▁▁█▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▂ ▂
  29 ns           Histogram: frequency by time          41 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
Conditional computation, with valid=false only for Axb:
BenchmarkTools.Trial: 10000 samples with 1000 evaluations per sample.
 Range (min … max):   98.883 ns … 181.697 ns   GC (min … max): 0.00% … 0.00%
 Time  (median):      99.705 ns                GC (median):    0.00%
 Time  (mean ± σ):   100.838 ns ±   3.906 ns   GC (mean ± σ):  0.00% ± 0.00%

  ▂▆█▆▃▂                     ▂▄▃▂▁▁                            ▂
  ██████▆▇█▇▇▇▅▅▆▁▃▃▁▁▃▃▄▁▄███████▇▇▆▇▆▆▄▅▄▁▄▃▃▅▅▆▆▇▆▅▆▆▆▄▅▆▄ █
  98.9 ns       Histogram: log(frequency) by time        115 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
overhead due to closure+wrapping for full computations          =    +41 ns
overhead due to closure+wrapping+logic for full computations    =   +142 ns
overhead due to closure+wrapping+logic for for computations     =  -1158 ns (<0 means savings)
overhead due to closure+wrapping+logic for partial computations =  -1089 ns (<0 means savings)

As expected, much time is saved when re-evaluations are not needed. When they are needed, the logic adds a small additional penalty (about 100 ns in the run above: +142 ns versus +41 ns).

Note

The code above is the basis for ComputationGraphs.generateComputeFunctions.

Parallel computations

Parallel evaluation is implemented by associating with each computation node one Threads.Task and one pair of Threads.Event objects. For each computation node i:

  • The task task[i]::Threads.Task is responsible for carrying out the evaluation of node i and synchronizing it with the other nodes.
  • The event request[i]::Threads.Event, created with autoreset enabled (Threads.Event(true)), is used to request task[i] to evaluate its node, by issuing notify(request[i]).
  • The event valid[i]::Threads.Event, created without autoreset (Threads.Event(false)), is used by node i to notify all other nodes that it has finished handling a computation request received through request[i].
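The difference between the two kinds of events can be checked directly. This small sketch (ev_auto and ev_keep are hypothetical names) assumes Julia 1.8 or later, where Threads.Event accepts the autoreset argument and reset(::Event) exists; like the code in this section, it reads the internal field set of the event, which is not part of the documented API.

```julia
ev_auto = Threads.Event(true)    # autoreset (like request[i]): a wait consumes one notify
ev_keep = Threads.Event(false)   # no autoreset (like valid[i]): stays set until reset

notify(ev_auto)
notify(ev_keep)

wait(ev_auto)                    # returns at once and clears the event
@assert !ev_auto.set

wait(ev_keep); wait(ev_keep)     # both return at once; the event remains set
@assert ev_keep.set

reset(ev_keep)                   # marks the event as not set again
@assert !ev_keep.set
```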

The following protocol is used:

  • All node tasks are spawned at once, and each task i immediately waits on request[i] for an evaluation request.

  • Upon receiving a request, task i checks which of its parents have valid data:

    1. For every parent p with missing data, it issues an evaluation request using notify(request[p]).
    2. After that, the task waits for these requests to be fulfilled by calling wait(valid[p]) for the same set of parent nodes.
  • Once all parents have valid data, node i performs its own computation and notifies any waiting child node that its data became valid using notify(valid[i]).

The operation described above makes the following assumptions:

  • Any thread that needs the value of node i should first issue an evaluation request using notify(request[i]) and then wait for its completion using wait(valid[i]).

  • When the value of a variable v changes, all its child nodes c need to be notified that their values have become invalid, by issuing reset(valid[c]).

  • To avoid races, these reset(valid[c]) calls must not be issued while computations are being performed.

Warning

The last assumption above should be enforced by an explicit locking mechanism, but that has not yet been implemented.
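The protocol, including invalidation after a variable changes, can be sketched on a two-node chain y = 2x followed by z = y + 1. chain_demo and all the names inside it are hypothetical; the reset is done only after request_value has returned, so no computation is running at that point, as required by the last assumption above.

```julia
# Hypothetical sketch: node 1 computes y = 2x, node 2 computes z = y + 1.
function chain_demo()
    x = [1.0, 2.0]; y = zeros(2); z = zeros(2)
    valid   = (Threads.Event(false), Threads.Event(false))  # stay set until reset
    request = (Threads.Event(true),  Threads.Event(true))   # each wait consumes a notify

    # task for node 1 (no computation parents)
    Threads.@spawn while true
        wait(request[1])
        if !valid[1].set
            y .= 2 .* x
            notify(valid[1])
        end
    end

    # task for node 2: request the parent if needed, wait for it, then compute
    Threads.@spawn while true
        wait(request[2])
        if !valid[2].set
            valid[1].set || notify(request[1])
            wait(valid[1])             # returns immediately if already valid
            z .= y .+ 1
            notify(valid[2])
        end
    end

    # a consumer first issues a request and then waits for the value to be valid
    request_value(i) = (notify(request[i]); wait(valid[i]))

    request_value(2)
    z_first = copy(z)

    # x changes while no computation is running: invalidate all nodes depending on it
    x .= [10.0, 20.0]
    reset.(valid)
    request_value(2)
    return z_first, copy(z)
end

z_first, z_second = chain_demo()
@assert z_first == [3.0, 5.0]
@assert z_second == [21.0, 41.0]
```

Calling wait(valid[1]) unconditionally is safe here, because waiting on an already-set event without autoreset returns immediately.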

Warning

For very large matrix multiplications, BLAS makes good use of multiple threads. In this case, we should not expect significant improvements over evaluating the computation graph sequentially. Instead, it is better to let BLAS manage all the threads and to evaluate the computation graph sequentially.
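A sketch of the two thread-allocation policies, using the BLAS thread controls from LinearAlgebra. Which one is better depends on the matrix sizes and on the shape of the graph, so using Sys.CPU_THREADS for option A is only an assumption (the BLAS library may also cap the number of threads it actually uses).

```julia
using LinearAlgebra

# Option A: few, very large multiplications. Evaluate the graph sequentially
# and let BLAS use several threads inside each multiplication.
BLAS.set_num_threads(Sys.CPU_THREADS)
@show BLAS.get_num_threads()

# Option B: several independent nodes. Evaluate nodes in parallel with Julia
# tasks and keep each BLAS call single-threaded, as done in this section.
BLAS.set_num_threads(1)
@assert BLAS.get_num_threads() == 1
```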

Parallelism implementation

We will illustrate the mechanism above with the computation of A*x+B*y for which the two multiplications can be parallelized. The corresponding graph has

  • four nodes A, x, B, y that correspond to variables; and
  • three computation nodes, one for each of the two multiplications and one for the addition.

We start by pre-allocating memory for all nodes

using Random
A = rand(Float64,4000,3000)
x = rand(Float64,3000,1000)
B = rand(Float64,4000,2500)
y = rand(Float64,2500,1000)
Ax = Matrix{Float64}(undef,4000,1000)
By = similar(Ax)
AxBy = similar(Ax)

and defining the computations for each node

using LinearAlgebra
function node_Ax!(out::Matrix{F},in1::Matrix{F}, in2::Matrix{F}) where {F}
    mul!(out,in1,in2)
end
function node_By!(out::Matrix{F},in1::Matrix{F}, in2::Matrix{F}) where {F}
    mul!(out,in1,in2)
end
function node_AxBy!(out::Matrix{F},in1::Matrix{F}, in2::Matrix{F}) where {F}
    @. out = in1 + in2
end

We used fairly big matrices for which the computation takes some time, as we can see below:

@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level

@time begin
    node_Ax!(Ax,A,x)
    node_By!(By,B,y)
    node_AxBy!(AxBy,Ax,By)
end
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
  0.932706 seconds

To implement the parallelization mechanism described above, we need two events per node:

valid=Tuple(Threads.Event(false) for _ in 1:3)
request=Tuple(Threads.Event(true) for _ in 1:3)
reset.(valid)
reset.(request)

The computation tasks can then be launched using

tasks=[
    Threads.@spawn while true
        wait(request[1])
        if !valid[1].set
            node_Ax!(Ax,A,x)    # this node's computation
            notify(valid[1])
        end
    end

    Threads.@spawn while true
        wait(request[2])
        if !valid[2].set
            node_By!(By,B,y)    # this node's computation
            notify(valid[2])
        end
    end

    Threads.@spawn while true
        wait(request[3])
        if !valid[3].set
            valid[1].set || notify(request[1])
            valid[2].set || notify(request[2])
            valid[1].set || wait(valid[1])
            valid[2].set || wait(valid[2])
            node_AxBy!(AxBy,Ax,By)  # this node's computation
            notify(valid[3])
        end
    end
]
3-element Vector{Task}:
 Task (runnable, started) @0x00007f048398f3a0
 Task (runnable, started) @0x00007f048398f530
 Task (runnable, started) @0x00007f048398f6c0
Note

Very similar code is used in computeSpawn! to parallelize the computation of general graphs.

Parallelism verification

To verify the operation of the approach outlined above, we request the value of the final node A*x+B*y and wait for that node to become valid:

using ThreadPinning
pinthreads(:cores)
@show Threads.nthreads()
BLAS.set_num_threads(1)
@show BLAS.get_num_threads()
@show Base.JLOptions().opt_level

fill!(Ax,0.0)
fill!(By,0.0)
fill!(AxBy,0.0)
reset.(valid)
println("valid before :",getproperty.(valid,:set))
@time begin
    notify(request[3])
    wait(valid[3])
end
println("valid after  :",getproperty.(valid,:set))
@assert Ax==A*x
@assert By==B*y
@assert AxBy==A*x+B*y
Threads.nthreads() = 2
BLAS.get_num_threads() = 1
(Base.JLOptions()).opt_level = 2
valid before :(false, false, false)
  0.511647 seconds (9 allocations: 416 bytes)
valid after  :(true, true, true)

When multiple hardware threads are available, the time reported by @time is roughly half of the sequential time (here 0.51 s versus 0.93 s), showing good use of the threads.

Note

We can check whether the Julia threads were successfully "pinned" to physical hardware threads using ThreadPinning.threadinfo(). In its colored output, red means that multiple Julia threads are running on the same hardware thread, and purple means that a Julia thread is running on a hyperthread. In either case, we should not expect true parallelism. This is often the case when code is run through a GitHub Action (as when generating this manual page) on a computer with a single core with Simultaneous Multithreading (SMT).

using ThreadPinning
@show Threads.nthreads()
@show pinthreads(:cores)
threadinfo()
Threads.nthreads() = 2
pinthreads(:cores) = nothing
Hostname: 	runnervmf4ws1
CPU(s): 	1 x AMD EPYC 7763 64-Core Processor
CPU target: 	znver3
Cores: 		2 (4 CPU-threads due to 2-way SMT)
NUMA domains: 	1 (2 cores each)

Julia threads: 	2

CPU socket 1
  0,1, 2,3


# = Julia thread, # = Julia thread on HT, # = >1 Julia thread

(Mapping: 1 => 0, 2 => 2,)