R2OpenBUGSの並行計算例

Parallel R2OpenBUGS Example ここから並行計算のコードを紹介します。普通に（逐次で）計算すると8分ぐらいかかりましたが、並行で計算すると2分強で終わりました。計算時間を約73%短縮できました。ちなみに、このリンクを参考にしました。

Parallel R2OpenBUGS Example

Greg Nishihara

2013 January 28

This is an example of how I use OpenBUGS with R2OpenBUGS in my daily work. I generate artificial data to illustrate how to run R2OpenBUGS in parallel.

First, load the required packages.

library(snow)
library(snowfall)
library(R2OpenBUGS)

Next, create the working directory if it does not already exist.

# Root working directory; every chain gets its own sub-directory below it.
WORKDIR <- "~/Data/R2OpenBUGS_Parallel_Example"
if (!file.exists(WORKDIR)) {
  # recursive = TRUE also creates missing parent directories (e.g. ~/Data);
  # a non-recursive dir.create() fails when the parent does not exist.
  dir.create(WORKDIR, recursive = TRUE)
}

Next, set the number of chains, the number of iterations per chain, the burn-in length, and the thinning interval.

# MCMC settings shared by the serial and the parallel runs.
NCHAINS <- 4      # number of chains; also the number of snowfall workers
NITER <- 4e4      # iterations per chain, burn-in included
NBURNIN <- 1e4    # iterations discarded at the start of each chain
NTHIN <- 1        # keep every NTHIN-th draw

Next, create separate directories for each chain.

# One directory per chain: <WORKDIR>/chain1 ... <WORKDIR>/chain<NCHAINS>.
# Each parallel worker runs OpenBUGS in its own directory so the files the
# chains write do not collide.
chain.directories <- file.path(WORKDIR, paste0("chain", seq_len(NCHAINS)))
# Create whichever directories are missing. Creating them only when *none*
# exist would skip creation whenever some already do (e.g. after NCHAINS is
# increased), leaving the extra chains without a directory.
missing.directories <- chain.directories[!file.exists(chain.directories)]
for (chain.dir in missing.directories) {
  dir.create(chain.dir, recursive = TRUE)
}

Generate 1000 fake data points from a normal distribution with mean 100 and standard deviation 10, and define the BUGS model.

# Simulated observations: 1000 draws from a normal distribution with mean 100
# and standard deviation 10. The seed is fixed so reruns reproduce the data.
set.seed(20130128)
fake.data <- rnorm(1000, 100, 10)

# BUGS model, exported to a file with write.model(); R never calls it.
# Its body is BUGS code: `~` declares a stochastic node, `<-` a deterministic
# one, and dnorm() takes a precision (1 / sd^2) rather than a standard
# deviation, hence `precision` below.
model <- function() {
  standard_error ~ dunif(0.001,1000)
  precision <- pow(standard_error, -2)
  mean ~ dnorm(0, 0.001)

  for(i in 1:number_of_observations) {
    values[i] ~ dnorm(mean, precision)
  }
}

For each chain, save the model to its directory.

# Write the same BUGS model into every chain directory. sapply() returns the
# (NULL) result of each write in a list named by file path.
model.file <- file.path(chain.directories, "modelfile.txt")
sapply(model.file, function(path) write.model(model, con = path))
## $`~/Data/R2OpenBUGS_Parallel_Example/chain1/modelfile.txt`
## NULL
## 
## $`~/Data/R2OpenBUGS_Parallel_Example/chain2/modelfile.txt`
## NULL
## 
## $`~/Data/R2OpenBUGS_Parallel_Example/chain3/modelfile.txt`
## NULL
## 
## $`~/Data/R2OpenBUGS_Parallel_Example/chain4/modelfile.txt`
## NULL

Set up the data for OpenBUGS.

# Data handed to OpenBUGS; the names match the nodes referenced in `model`.
dataset <- list(
  number_of_observations = length(fake.data),
  values = fake.data
)

Set up the parameters to monitor, and define a function that runs a single chain in its own directory (used by the parallel run).

# Parameters whose samples OpenBUGS should save.
monitor <- list(
  "mean",
  "standard_error"
)
# Run a single MCMC chain with OpenBUGS inside that chain's own directory.
#
# Called once per chain via sfLapply() so that the chains run concurrently in
# separate worker processes, each writing to a separate directory.
#
# Args:
#   chain:     chain index; selects the sub-directory <WORKDIR>/chain<chain>
#              and, by default, the OpenBUGS random-number seed.
#   WORKDIR:   root working directory; <WORKDIR>/chain<chain> must already
#              contain "modelfile.txt".
#   dataset:   named list of data passed to bugs().
#   monitor:   names of the parameters to save.
#   NITER:     total iterations for this chain (burn-in included).
#   NBURNIN:   iterations discarded at the start of the chain.
#   NTHIN:     thinning interval.
#   bugs.seed: OpenBUGS random-number seed. Defaults to `chain` so each chain
#              uses a different stream; bugs() itself defaults to
#              bugs.seed = 1, which would give every parallel chain the same
#              stream. NOTE(review): R2OpenBUGS documents the seed as an
#              integer from 1 to 14, which limits this default to 14 chains.
#
# Returns:
#   The value of bugs() for this single chain (codaPkg = FALSE).
parallel.bugs <- function(chain, WORKDIR, dataset, monitor, NITER, NBURNIN,
                          NTHIN, bugs.seed = chain) {
  sub.folder <- file.path(WORKDIR, paste0("chain", chain))

  # Starting values. The mean is drawn from its prior in the model,
  # dnorm(0, 0.001), i.e. sd = 1 / sqrt(0.001). A fixed start such as
  # rnorm(1, 0, 0) (always 0) puts every chain at the same point, which
  # undermines the Gelman-Rubin convergence diagnostic.
  inits <- function() {
    list(standard_error = runif(1, 0, 100),
         mean = rnorm(1, 0, 1 / sqrt(0.001)))
  }

  bugs(data = dataset, inits = inits, parameters.to.save = monitor,
       n.iter = NITER, n.chains = 1, n.burnin = NBURNIN, n.thin = NTHIN,
       model.file = "modelfile.txt", codaPkg = FALSE,
       working.directory = sub.folder, bugs.seed = bugs.seed)
}

Run a serial version of the code.

# Serial baseline: all NCHAINS chains in a single OpenBUGS run inside WORKDIR.
write.model(model, con = file.path(WORKDIR, "modelfile.txt"))
start.time <- Sys.time()
# Starting values. The mean is drawn from its prior in the model,
# dnorm(0, 0.001), i.e. sd = 1 / sqrt(0.001). A fixed start such as
# rnorm(1, 0, 0) (always 0) puts every chain at the same point, which
# undermines the Rhat / Gelman-Rubin diagnostic.
inits <- function() {
  list(standard_error = runif(1, 0, 100),
       mean = rnorm(1, 0, 1 / sqrt(0.001)))
}
bugs(data = dataset, inits = inits, parameters.to.save = monitor,
     n.iter = NITER, n.chains = NCHAINS, n.burnin = NBURNIN, n.thin = NTHIN,
     model.file = "modelfile.txt", codaPkg = FALSE,
     working.directory = WORKDIR)
## Inference for Bugs model at "modelfile.txt", 
## Current: 4 chains, each with 40000 iterations (first 10000 discarded)
## Cumulative: n.sims = 120000 iterations saved
##                  mean  sd   2.5%    25%    50%    75%  97.5% Rhat  n.eff
## mean             99.5 0.3   98.9   99.3   99.6   99.8  100.2    1 120000
## standard_error   10.1 0.2    9.7   10.0   10.2   10.3   10.6    1 100000
## deviance       7471.4 2.0 7469.0 7470.0 7471.0 7472.0 7477.0    1  95000
## 
## For each parameter, n.eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
## 
## DIC info (using the rule, pD = var(deviance)/2)
## pD = 2.0 and DIC = 7473.4
## DIC is an estimate of expected predictive error (lower deviance is better).
# Elapsed wall-clock time of the serial run (a difftime object).
finish.serial.time <- difftime(Sys.time(), start.time)

Run the parallel version of the code.

# Time the parallel run, starting a snowfall cluster with one worker per chain.
start.time <- Sys.time()
snowfall::sfInit(cpus = NCHAINS, parallel = TRUE)
## Warning: Unknown option on commandline: options(encoding
## R Version:  R version 2.15.2 (2012-10-26)
## snowfall 1.84 initialized (using snow 0.3-10): parallel execution on 4
## CPUs.
# Load R2OpenBUGS on the master and on every worker so bugs() is available
# inside parallel.bugs().
snowfall::sfLibrary(R2OpenBUGS)
## Library R2OpenBUGS loaded.
## Library R2OpenBUGS loaded in cluster.
## Warning: 'keep.source' is deprecated and will be ignored
# Run chain i on worker i. If any worker fails, shut the cluster down before
# re-raising the error so worker processes are not left running; on success
# the cluster is stopped by the explicit sfStop() call that follows.
tryCatch(
  sfLapply(seq_len(NCHAINS), fun = parallel.bugs, dataset = dataset,
           monitor = monitor, WORKDIR = WORKDIR, NITER = NITER,
           NBURNIN = NBURNIN, NTHIN = NTHIN),
  error = function(e) {
    sfStop()
    stop(e)
  }
)
## [[1]]
## Inference for Bugs model at "modelfile.txt", 
## Current: 1 chains, each with 40000 iterations (first 10000 discarded)
## Cumulative: n.sims = 30000 iterations saved
##                  mean  sd   2.5%    25%    50%    75%  97.5%
## mean             99.5 0.3   98.9   99.3   99.6   99.8  100.2
## standard_error   10.2 0.2    9.7   10.0   10.2   10.3   10.6
## deviance       7471.4 2.0 7469.0 7470.0 7471.0 7472.0 7477.0
## 
## DIC info (using the rule, pD = var(deviance)/2)
## pD = 2.0 and DIC = 7473.4
## DIC is an estimate of expected predictive error (lower deviance is better).
## 
## [[2]]
## Inference for Bugs model at "modelfile.txt", 
## Current: 1 chains, each with 40000 iterations (first 10000 discarded)
## Cumulative: n.sims = 30000 iterations saved
##                  mean  sd   2.5%    25%    50%    75%  97.5%
## mean             99.5 0.3   98.9   99.3   99.6   99.8  100.2
## standard_error   10.2 0.2    9.7   10.0   10.2   10.3   10.6
## deviance       7471.4 2.0 7469.0 7470.0 7471.0 7472.0 7477.0
## 
## DIC info (using the rule, pD = var(deviance)/2)
## pD = 2.0 and DIC = 7473.4
## DIC is an estimate of expected predictive error (lower deviance is better).
## 
## [[3]]
## Inference for Bugs model at "modelfile.txt", 
## Current: 1 chains, each with 40000 iterations (first 10000 discarded)
## Cumulative: n.sims = 30000 iterations saved
##                  mean  sd   2.5%    25%    50%    75%  97.5%
## mean             99.5 0.3   98.9   99.3   99.6   99.8  100.2
## standard_error   10.2 0.2    9.7   10.0   10.2   10.3   10.6
## deviance       7471.4 2.0 7469.0 7470.0 7471.0 7472.0 7477.0
## 
## DIC info (using the rule, pD = var(deviance)/2)
## pD = 2.0 and DIC = 7473.4
## DIC is an estimate of expected predictive error (lower deviance is better).
## 
## [[4]]
## Inference for Bugs model at "modelfile.txt", 
## Current: 1 chains, each with 40000 iterations (first 10000 discarded)
## Cumulative: n.sims = 30000 iterations saved
##                  mean  sd   2.5%    25%    50%    75%  97.5%
## mean             99.5 0.3   98.9   99.3   99.6   99.8  100.2
## standard_error   10.2 0.2    9.7   10.0   10.2   10.3   10.6
## deviance       7471.4 2.0 7469.0 7470.0 7471.0 7472.0 7477.0
## 
## DIC info (using the rule, pD = var(deviance)/2)
## pD = 2.0 and DIC = 7473.4
## DIC is an estimate of expected predictive error (lower deviance is better).
# Shut down the snowfall worker processes.
snowfall::sfStop()
## Stopping cluster
# Elapsed wall-clock time of the parallel run (a difftime object).
finish.parallel.time <- difftime(Sys.time(), start.time)

The serial version of the code took 7.87 minutes and the parallel version of the code took 2.11 minutes.

Load the monitored samples from each chain's CODA output file.

# Each chain was a single-chain OpenBUGS job, so its samples sit in
# CODAchain1.txt inside that chain's own directory.
CODAchains <- file.path(chain.directories, "CODAchain1.txt")
output <- read.bugs(CODAchains)
## Abstracting deviance ... 30000 valid values
## Abstracting mean ... 30000 valid values
## Abstracting standard_error ... 30000 valid values
## Abstracting deviance ... 30000 valid values
## Abstracting mean ... 30000 valid values
## Abstracting standard_error ... 30000 valid values
## Abstracting deviance ... 30000 valid values
## Abstracting mean ... 30000 valid values
## Abstracting standard_error ... 30000 valid values
## Abstracting deviance ... 30000 valid values
## Abstracting mean ... 30000 valid values
## Abstracting standard_error ... 30000 valid values

Run the Gelman-Rubin diagnostic on the chains. This requires the coda package.

library(coda)
## Loading required package: lattice
# Gelman-Rubin potential scale reduction factors across the chains.
coda::gelman.diag(x = output)
## Potential scale reduction factors:
## 
##                Point est. Upper C.I.
## deviance                1          1
## mean                    1          1
## standard_error          1          1
## 
## Multivariate psrf
## 
## 1
# Plot how the shrink factor evolves as iterations accumulate.
coda::gelman.plot(x = output)

plot of chunk gelman_rubin_diagnostics

Examine the traceplots and the density plots

# Trace and density plots for the `mean` parameter.
plot(x = output[, "mean"], main = "Mean")

plot of chunk plot_results

# Trace and density plots for the `standard_error` parameter.
plot(x = output[, "standard_error"], main = "Standard error")

plot of chunk plot_results

Examine the autocorrelation in the first chain.

# Autocorrelation of `mean` in the first chain only (the first niter(output)
# values of the concatenated chains are exactly chain 1's draws).
mean.chain1 <- as.numeric(output[[1]][, "mean"])
acf(mean.chain1, main = "Mean")

plot of chunk autocorrelationplots

# Autocorrelation of `standard_error` in the first chain only.
se.chain1 <- as.numeric(output[[1]][, "standard_error"])
acf(se.chain1, main = "Standard error")

plot of chunk autocorrelationplots

コメント

このブログの人気の投稿

RStudioとggplot():プロットができないとき

光合成関連の単語

大村湾調査!! ~海藻・生き物編~