library(tensorflow)
library(keras)

## Build, compile and fit a sequential CNN classifier with keras.
##
## Args:
##   train:       list with $x (4-d image array: n x h x w x channels) and
##                $y (integer class labels), as consumed by keras::fit().
##   complexity:  total number of conv layers. The first two (32 and 64
##                filters) are always present; complexity > 2 appends extra
##                64-filter conv layers plus one more max-pooling layer.
##   epoch:       number of training epochs.
##   batch_size:  minibatch size forwarded to fit().
##   valid_split: fraction of training data held out for validation.
##   verbose:     verbosity level forwarded to fit().
##   input_shape: dimensions of a single input image; defaults to
##                c(32,32,3) (CIFAR-10), kept as the original behavior.
##
## Returns: the fitted keras model.
cnn <- function(train, complexity=2, epoch=10, batch_size=50, valid_split=0.3,
                verbose=0, input_shape=c(32,32,3))
{
    model <- keras_model_sequential()

    ## Default architecture: two conv layers followed by max pooling.
    model %>%
        layer_conv_2d(filters = 32, kernel_size = c(3,3), activation = "relu",
                      input_shape = input_shape) %>%
        layer_conv_2d(filters = 64, kernel_size = c(3,3), activation = "relu") %>%
        layer_max_pooling_2d(pool_size = c(2,2))

    if (complexity > 2) {
        ## Append the extra conv layers requested beyond the default two.
        ## seq_len() instead of 1:(complexity - 2) — safe sequence idiom.
        for (i in seq_len(complexity - 2)) {
            model %>%
                layer_conv_2d(filters = 64, kernel_size = c(3,3), activation = "relu")
        }

        ## Use max pooling once more after the extra conv stack.
        model %>% layer_max_pooling_2d(pool_size = c(2,2))
    }

    ## Flatten the pooled feature maps into a vector and feed it through a
    ## dense hidden layer into a 10-way softmax head.
    model %>%
        layer_flatten() %>%
        layer_dense(units = 64, activation = "relu") %>%
        layer_dense(units = 10, activation = "softmax")
#   summary(model)

    ## Sparse categorical crossentropy: labels are integer class ids
    ## (not one-hot encoded).
    model %>% compile(
        optimizer = "adam",
        loss = "sparse_categorical_crossentropy",
        metrics = "accuracy"
    )

    model %>% fit(
        x = train$x, y = train$y,
        batch_size = batch_size,
        epochs = epoch,
        validation_split = valid_split,
        verbose = verbose
    )

    ## keras models are modified in place; returned explicitly for clarity.
    return(model)
}
suppressMessages(library(tidyverse))
suppressMessages(library(ranger))
suppressMessages(library(parallel)) # both ranger and keras run in parallel internally!
#suppressMessages(library(doMC))
#registerDoMC(cores=detectCores()-1)


set.seed(12345)

# CIFAR-10 class names, indexed by label 0..9 (offset by +1 below).
class_names <- c('airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

# Loads `cifar10` into the global env (print() shows the loaded names).
print(load(url("http://www.cis.jhu.edu/~parky/RF/Cifar10.RData"))) # for RF, n x 3072, col3073 is label, col3074 is Description
table(cifar10$Description) 

# Loads `dat.cnn` (image array $dat plus labels $Y) into the global env.
print(load(url("http://www.cis.jhu.edu/~parky/RF/Cifar10.CNN.RData"))) # for CNN, n x 32 x 32 x 3
table(dat.cnn$Y)

# Experiment settings: K classes per pair, train/test sizes, and the
# (roughly log-spaced) training-subsample sizes to sweep over.
K <- 2
ntrain <- 10000 
ntest <- 2000 
svec <- c(100,166,278,464,774,1292,2154,3594,5994,8000,10000)

# CNN
epoch <- 10
# RF
ntree <- 500 #c(100, 300, 500) 

# Pick a pair of classes
c1 <- 5; c2 <- 7
use <- c(c1,c2)
# Human-readable class names (labels are 0-based, hence use+1).
(clab <- class_names[use+1])
# Output directory name, e.g. "57" for classes 5 and 7.
(cname <- paste0(as.character(use), collapse=""))
dir.create(cname, showWarnings=FALSE)

# Keep only rows of the chosen pair; drop the numeric Label column
# (Description remains as the response for the RF).
dat <- cifar10 %>% dplyr::filter(Label %in% use) %>% dplyr::select(-Label); nrow(dat)
dat$Description <- factor(dat$Description); table(dat$Description)
dim(dat) 

# Fixed train/test split for the RF experiments: first ntrain rows
# train, last ntest rows test.
train <- head(dat, ntrain)
test <- tail(dat, ntest)
train %>% count(Description)
test %>% count(Description)

## Run RF: for each training-set size, fit a random forest on a random
## subsample and record its test error and training time.
#tmp <- foreach (nsamp = svec) %dopar% { # use for CNN!!
for (nsamp in svec) { # use for CNN!!
    ## Reseed per size so each subsample is reproducible on its own.
    set.seed(12345 + nsamp)
    samp <- sample(nrow(train), nsamp, replace=FALSE)

    (rf.time <- system.time(rf <- ranger(Description ~ ., data=train[samp,], num.trees=ntree[1])))
    rf.test <- predict(rf, test %>% select(-Description))
    ## Misclassification rate on the held-out test set.
    (rf.error <- sum(rf.test$pred != test$Description) / nrow(test))

    ## BUG FIX: the original referenced undefined `rf10.error`/`rf10.time`
    ## here; the variables computed above are `rf.error`/`rf.time`.
    cat(cname, ": ", nsamp, ": rf.err = ", round(rf.error,2), ", rf.time = ", rf.time[3], "\n")
    save(nsamp, ntree, rf.time, rf.error, file=paste0(cname, "/out-rf-",nsamp,"-ntrain",ntrain,"-ntest",ntest,".Rbin"))
}

## Prepare CNN train/test sets: keep only images of the two chosen classes,
## then use the first ntrain of them for training and the last ntest for
## testing (same split convention as the RF section above).
ind.y <- which(dat.cnn$Y %in% use)
train.x <- dat.cnn$dat[ind.y, , , ][seq_len(ntrain), , , ]
train.y <- dat.cnn$Y[ind.y][seq_len(ntrain)]; table(train.y)
train <- list(x = train.x, y = train.y)
test.x <- dat.cnn$dat[ind.y, , , ][seq(length(ind.y) - ntest + 1, length(ind.y)), , , ]
test.y <- dat.cnn$Y[ind.y][seq(length(ind.y) - ntest + 1, length(ind.y))]; table(test.y)
test <- list(x = test.x, y = test.y)

## Run CNN: for each training-set size, fit a 3-conv-layer (shallow) and an
## 8-conv-layer (deep) model, then record test error and training time.
for (nsamp in svec) { # use for CNN!!
    ## Reseed per size so the subsample is reproducible independently.
    set.seed(12345 + nsamp)

    train.s <- train
    test.s <- test

    ## Subsample nsamp training images together with their labels.
    samp <- sample(nrow(train.s$x), nsamp, replace=FALSE)
    train.s <- list(x=train.s$x[samp,,,], y=train.s$y[samp]); table(train.s$y)

    cnn.fname <- paste0(cname,"/out-cnn-",nsamp,"-ntrain",ntrain,"-ntest",ntest,".Rbin")

    cnn3.time <- system.time(model3 <- cnn(train.s, complexity=3, epoch=epoch, verbose=0))
    cnn3.error <- 1 - evaluate(model3, test.s$x, test.s$y, verbose = 0)$accuracy

    cnn8.time <- system.time(model8 <- cnn(train.s, complexity=8, epoch=epoch, verbose=0))
    cnn8.error <- 1 - evaluate(model8, test.s$x, test.s$y, verbose = 0)$accuracy

    cat(cname, ": ", nsamp, ": cnn3.err = ", round(cnn3.error,2), ": cnn3.time = ", cnn3.time[3], "\n")
    cat(cname, ": ", nsamp, ": cnn8.err = ", round(cnn8.error,2), ": cnn8.time = ", cnn8.time[3], "\n")
    cat("----------------------------------------------------------------------------------\n")

    ## FIX: also persist the timings — they are measured and printed above,
    ## and the RF loop saves rf.time, but the original save() dropped
    ## cnn3.time/cnn8.time, losing the timing data.
    save(nsamp, epoch, cnn3.time, cnn8.time, cnn3.error, cnn8.error, file=cnn.fname)
}