📜  Julia 中的基准测试

📅  最后修改于: 2021-11-25 04:39:48             🧑  作者: Mango

在 Julia 中,大多数代码都经过了速度效率的检查。 Julia 的标志之一是它比其他科学计算同行( Python、R、Matlab快得多。为了验证这一点,我们经常倾向于比较跨不同语言运行的代码块速度性能。在我们尝试多种方法来解决问题的情况下,有必要决定最有效的方法,在这种情况下,我们显然会选择最快的方法。
在 Julia 中测试代码块的最传统方法之一是使用@time 宏。在 Julia 中,我们说全局对象会降低性能。

Python3
# Import the library
using Random
   
# Using the MersenneTwister rng
# Here 1234 is a seed value
rng = MersenneTwister(1234);
  
# Generate random data
x = rand(rng, 1000);
  
# a function that considers
# x as a global variable
function prod_global()
    prod = 0.0
    for i in x
       prod *= i
    end
    return prod
end;
  
# a function that accepts
# x as a local variable
function prod_local(x)
    prod = 0.0
    for i in x
       prod *= i
    end
    return prod
end;


Python3
# first run
@time prod_global()
 
# second run
@time prod_global()


Python3
# first run
@time prod_local(x)
 
# second run
@time prod_local(x)


Python3
# Import the library
using Random
   
# Using the MersenneTwister rng
# Here 1234 is a seed value
rng = MersenneTwister(1234);
  
# Generate random data
x = rand(rng, 1000);
  
# a function that considers
# x as a global variable
function sum_global()
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
 
# First we force compile the function
sum_global()
  
# Import profiling library
using Profile
  
# Profile sum_global
@profile sum_global
  
# Print the results
Profile.print()


Python3
# Make sure you run the following code
# in a fresh repl environment
# This will clear results from previous profiling
# Import the library
using Random
   
# Using the MersenneTwister rng
# Here 1234 is a seed value
rng = MersenneTwister(1234);
  
# Generate random data
x = rand(rng, 1000);
  
# A function that accepts
# x as a local variable
function sum_local(x)
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
  
# Force compile the function
sum_local(x)
  
# Import the library
using Profile
  
# Profile sum_local()
@profile sum_local(x)
  
# Print the results
Profile.print()


Python3
# Import the library
using Random
   
# Using the MersenneTwister rng
# Here 1234 is a seed value
rng = MersenneTwister(1234);
  
# Import the package
using BenchmarkTools
  
# Generate random data
x = rand(rng, 1000);
  
# a function that considers
# x as a global variable
function sum_global()
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
  
# A function that accepts
# x as a local variable
function sum_local(x)
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
 
# Benchmark the sum_global() function
@benchmark sum_global()
  
# Benchmark the sum_local(x) function
@benchmark sum_local(x)


Python3
# @btime for sum_global()
@btime sum_global()
 
# @btime for sum_local(x)
@btime sum_local(x)
 
# @belapsed for sum_global()
@belapsed sum_global()
 
# @belapsed for sum_local(x)
@belapsed sum_local(x)


Python3
# apply custom benchmarks
bg = @benchmark sum_global() seconds=1 time_tolerance=0.01
 
# apply custom benchmarks
bl = @benchmark sum_local(x) seconds=1 time_tolerance=0.01


输出:

现在为了比较这两个函数,我们将使用我们的@time 宏。对于新环境,在第一次调用(@time prod_global())时,prod_global()函数和其他计时所需的函数被编译,因此不应认真对待首次运行的结果。

蟒蛇3

# first run
@time prod_global()
 
# second run
@time prod_global()

输出:

让我们尝试使用本地 x 测试该函数

蟒蛇3

# first run
@time prod_local(x)
 
# second run
@time prod_local(x)

输出:

分析 Julia 代码

对于 Julia 中的分析代码,我们使用@profile 宏。它对正在运行的代码进行测量,并产生输出来帮助开发人员分析每行花费时间。它通常用于识别阻碍性能的代码块/函数中的瓶颈
让我们试着分析我们之前的例子,看看为什么全局变量会阻碍性能!
此外,我们现在将用 sum 替换乘积,以便计算在任何时候都不会趋向于无穷大或零。

蟒蛇3

# Import the library
using Random
   
# Using the MersenneTwister rng
# Here 1234 is a seed value
rng = MersenneTwister(1234);
  
# Generate random data
x = rand(rng, 1000);
  
# a function that considers
# x as a global variable
function sum_global()
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
 
# First we force compile the function
sum_global()
  
# Import profiling library
using Profile
  
# Profile sum_global
@profile sum_global
  
# Print the results
Profile.print()

输出:

蟒蛇3

# Make sure you run the following code
# in a fresh repl environment
# This will clear results from previous profiling
# Import the library
using Random
   
# Using the MersenneTwister rng
# Here 1234 is a seed value
rng = MersenneTwister(1234);
  
# Generate random data
x = rand(rng, 1000);
  
# A function that accepts
# x as a local variable
function sum_local(x)
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
  
# Force compile the function
sum_local(x)
  
# Import the library
using Profile
  
# Profile sum_local()
@profile sum_local(x)
  
# Print the results
Profile.print()

输出:

您一定想知道我们如何简单地基于@time分析来总结代码的性能,并且许多这样的决定是通过对各种试验的一致分析和随着时间的推移观察代码块的性能而做出的。 Julia 有一个扩展包来运行可靠的基准测试,称为Benchmark Tools.jl

对代码进行基准测试

使用 Benchmark Tools 对代码块进行基准测试的最传统方法之一是@benchmark
考虑上面sum_local(x)sum_global() 的例子:

蟒蛇3

# Import the library
using Random
   
# Using the MersenneTwister rng
# Here 1234 is a seed value
rng = MersenneTwister(1234);
  
# Import the package
using BenchmarkTools
  
# Generate random data
x = rand(rng, 1000);
  
# a function that considers
# x as a global variable
function sum_global()
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
  
# A function that accepts
# x as a local variable
function sum_local(x)
    sum = 0.0
    for i in x
       sum += i
    end
    return sum
end;
 
# Benchmark the sum_global() function
@benchmark sum_global()
  
# Benchmark the sum_local(x) function
@benchmark sum_local(x)

输出:

@benchmark 宏给出了很多细节(内存分配、最短时间、平均时间、中值时间、样本等),对许多开发人员都很有用,但有时我们需要快速的特定参考,例如: @btime 宏在返回表达式值之前打印最短时间和内存分配@belapsed 宏以秒为单位返回最短时间

蟒蛇3

# @btime for sum_global()
@btime sum_global()
 
# @btime for sum_local(x)
@btime sum_local(x)
 
# @belapsed for sum_global()
@belapsed sum_global()
 
# @belapsed for sum_local(x)
@belapsed sum_local(x)

输出:

@benchmark 宏为我们提供了配置基准过程的方法。
您可以将以下关键字参数传递给@benchmark,并运行以配置执行过程:

蟒蛇3

# apply custom benchmarks
bg = @benchmark sum_global() seconds=1 time_tolerance=0.01
 
# apply custom benchmarks
bl = @benchmark sum_local(x) seconds=1 time_tolerance=0.01

输出: