library(tensorflow)
library(keras)
#' Build, compile, and train a small Keras CNN classifier.
#'
#' The network always starts with two 3x3 ReLU conv layers (32 and 64 filters)
#' followed by a 2x2 max-pool. When `complexity > 2`, `complexity - 2` extra
#' 64-filter 3x3 conv layers and one more 2x2 max-pool are appended. A 64-unit
#' ReLU dense layer and a softmax output layer finish the model, which is
#' trained with Adam on sparse categorical cross-entropy.
#'
#' @param train List with elements `x` (input array whose first dimension
#'   indexes samples) and `y` (integer class labels in `0:(num_classes - 1)`,
#'   as `sparse_categorical_crossentropy` requires).
#' @param complexity Total number of conv layers; values <= 2 give the base
#'   two-layer network. With 32x32 inputs and the default "valid" conv padding,
#'   values above 8 shrink the feature map to nothing and model building fails.
#' @param epoch Number of training epochs.
#' @param batch_size Mini-batch size passed to `fit()`.
#' @param valid_split Fraction of `train` that `fit()` holds out for validation.
#' @param verbose Verbosity level passed to `fit()`.
#' @param input_shape Shape of one input sample (height, width, channels).
#' @param num_classes Number of units in the softmax output layer.
#' @return The trained Keras model. Keras models are modified in place, so the
#'   returned object is the one `fit()` trained.
cnn <- function(train, complexity = 2, epoch = 10, batch_size = 50,
                valid_split = 0.3, verbose = 0,
                input_shape = c(32, 32, 3), num_classes = 10) {
  model <- keras_model_sequential()

  # Base feature extractor: two conv layers and a max-pool.
  model %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), activation = "relu",
                  input_shape = input_shape) %>%
    layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") %>%
    layer_max_pooling_2d(pool_size = c(2, 2))

  if (complexity > 2) {
    # Layers are appended to `model` in place, so no reassignment is needed.
    for (i in seq_len(complexity - 2)) {
      model %>%
        layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu")
    }
    # One more max-pool after the extra conv stack.
    model %>% layer_max_pooling_2d(pool_size = c(2, 2))
  }

  # Flatten the pooled feature maps into a vector and classify with a small
  # dense head.
  model %>%
    layer_flatten() %>%
    layer_dense(units = 64, activation = "relu") %>%
    layer_dense(units = num_classes, activation = "softmax")

  model %>% compile(
    optimizer = "adam",
    loss = "sparse_categorical_crossentropy",
    metrics = "accuracy"
  )

  model %>% fit(
    x = train$x, y = train$y,
    batch_size = batch_size,
    epochs = epoch,
    validation_split = valid_split,
    verbose = verbose
  )

  model
}
suppressMessages(library(tidyverse))
suppressMessages(library(ranger))
suppressMessages(library(parallel)) # both ranger and keras run in parallel internally!
#suppressMessages(library(doMC))
#registerDoMC(cores=detectCores()-1)
set.seed(12345)
class_names <- c('airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck')

# Load a remote .RData file into `envir` and return the names of the objects it
# created. As in the ?load example, the url() connection is closed explicitly
# (on.exit also covers errors) instead of being left for the garbage collector.
load_url <- function(address, envir = parent.frame()) {
  con <- url(address)
  on.exit(close(con), add = TRUE)
  load(con, envir = envir)
}

# For RF: n x 3072 pixel columns; col 3073 is Label, col 3074 is Description.
print(load_url("http://www.cis.jhu.edu/~parky/RF/Cifar10.RData"))
table(cifar10$Description)
# For CNN: an n x 32 x 32 x 3 image array.
print(load_url("http://www.cis.jhu.edu/~parky/RF/Cifar10.CNN.RData"))
table(dat.cnn$Y)

K <- 2
ntrain <- 10000
ntest <- 2000
# Training-set sizes for the learning curve (largest equals ntrain).
svec <- c(100, 166, 278, 464, 774, 1292, 2154, 3594, 5994, 8000, 10000)
# CNN
epoch <- 10
# RF
ntree <- 500 # c(100, 300, 500)

# Pick a pair of classes (0-based labels; class_names is indexed with +1).
c1 <- 5; c2 <- 7
use <- c(c1, c2)
(clab <- class_names[use + 1])
(cname <- paste0(as.character(use), collapse = ""))
# Output directory for the per-size result files.
dir.create(cname, showWarnings = FALSE)

dat <- cifar10 %>% dplyr::filter(Label %in% use) %>% dplyr::select(-Label); nrow(dat)
dat$Description <- factor(dat$Description); table(dat$Description)
dim(dat)
# The split takes the first `ntrain` and the last `ntest` rows. Fail fast
# rather than let the two sets silently overlap and leak training data into
# the test error.
if (nrow(dat) < ntrain + ntest) {
  stop("need at least ntrain + ntest = ", ntrain + ntest,
       " rows for classes ", cname, ", found ", nrow(dat), call. = FALSE)
}
train <- head(dat, ntrain)
test <- tail(dat, ntest)
train %>% count(Description)
test %>% count(Description)
## Run RF
# Learning curve for a random forest: for each training-set size in `svec`,
# fit on a random subsample of `train`, score on the fixed `test` set, print
# the test error and elapsed time, and save both under `cname/`.
#tmp <- foreach (nsamp = svec) %dopar% { # parallel variant (needs doMC above)
for (nsamp in svec) {
  # One seed per size, so each subsample is reproducible. ranger derives its
  # own seed from R's RNG when `seed` is not supplied.
  set.seed(12345 + nsamp)
  samp <- sample(nrow(train), nsamp, replace = FALSE)
  (rf.time <- system.time(
    rf <- ranger(Description ~ ., data = train[samp, ], num.trees = ntree[1])
  ))
  rf.test <- predict(rf, test %>% dplyr::select(-Description))
  # Spell out `predictions`; `$pred` only worked through partial matching.
  (rf.error <- mean(rf.test$predictions != test$Description))
  # Report this size's test error and elapsed (wall-clock) seconds.
  cat(cname, ": ", nsamp, ": rf10.err = ", round(rf.error, 2),
      ", rf10.time = ", rf.time[3], "\n")
  save(nsamp, ntree, rf.time, rf.error,
       file = paste0(cname, "/out-rf-", nsamp, "-ntrain", ntrain,
                     "-ntest", ntest, ".Rbin"))
}
## Run CNN
# Learning curve for two CNN depths (complexity 3 and 8) on the same class
# pair. `train` and `test` are rebound here from data frames to
# list(x = image array, y = labels), the shape cnn() expects.
ind.y <- which(dat.cnn$Y %in% use)
# Same head/tail split as for RF: refuse to let train and test overlap.
if (length(ind.y) < ntrain + ntest) {
  stop("need at least ntrain + ntest = ", ntrain + ntest,
       " CNN images for classes ", cname, ", found ", length(ind.y),
       call. = FALSE)
}
# Index the large image array once per split, instead of first copying every
# selected image and then slicing that copy (same rows, same order).
train.idx <- head(ind.y, ntrain)
test.idx <- tail(ind.y, ntest)
train.x <- dat.cnn$dat[train.idx, , , ]
train.y <- dat.cnn$Y[train.idx]; table(train.y)
train <- list(x = train.x, y = train.y)
test.x <- dat.cnn$dat[test.idx, , , ]
test.y <- dat.cnn$Y[test.idx]; table(test.y)
test <- list(x = test.x, y = test.y)

for (nsamp in svec) {
  # NOTE(review): set.seed() fixes the R-side subsample. TensorFlow weight
  # initialisation and batch shuffling are presumably not seeded by it, so CNN
  # errors may vary between runs; confirm (e.g. tensorflow::set_random_seed()).
  set.seed(12345 + nsamp)
  test.s <- test
  samp <- sample(nrow(train$x), nsamp, replace = FALSE)
  # drop = FALSE keeps x 4-D even for a one-image subsample.
  train.s <- list(x = train$x[samp, , , , drop = FALSE], y = train$y[samp]); table(train.s$y)
  cnn.fname <- paste0(cname, "/out-cnn-", nsamp, "-ntrain", ntrain,
                      "-ntest", ntest, ".Rbin")
  # Labels are the raw class ids in `use` (both < 10), so they index cnn()'s
  # 10-way sparse softmax output directly.
  cnn3.time <- system.time(model3 <- cnn(train.s, complexity = 3, epoch = epoch, verbose = 0))
  cnn3.error <- 1 - evaluate(model3, test.s$x, test.s$y, verbose = 0)$accuracy
  cnn8.time <- system.time(model8 <- cnn(train.s, complexity = 8, epoch = epoch, verbose = 0))
  cnn8.error <- 1 - evaluate(model8, test.s$x, test.s$y, verbose = 0)$accuracy
  cat(cname, ": ", nsamp, ": cnn3.err = ", round(cnn3.error, 2), ": cnn3.time = ", cnn3.time[3], "\n")
  cat(cname, ": ", nsamp, ": cnn8.err = ", round(cnn8.error, 2), ": cnn8.time = ", cnn8.time[3], "\n")
  cat("----------------------------------------------------------------------------------\n")
  save(nsamp, epoch, cnn3.error, cnn8.error, file = cnn.fname)
}