library("igraph")
library("fpc")
library("popbio")

source("http://www.cis.jhu.edu/~parky/HSBM/ccc_utils.r")

## Simulation setup ----------------------------------------------------------
## R top-level blocks, each with a random size and one of three motif types.
set.seed(1234)
R <- 8
## Block sizes: 200 + 100*k with k = ceiling(5*U) in {1, ..., 5}, i.e. 300..700.
(nR.vec <- 100*ceiling(runif(R)*5) + 200)
# [1] 300 600 600 600 700 600 300 400
## Motif labels for the SBM
(motif.vec <- sample(1:3, size = R, replace = TRUE))
# [1] 2 2 3 2 1 3 1 3
## Per-vertex labels: the vertex's top-level block id, and that block's motif.
vlab <- rep(seq_len(R), nR.vec)
mlab <- motif.vec[vlab]
## Probability passed as `p` to sample_hierarchical_sbm() below.
p <- 0.01

## Per-block SBM parameters, selected by the block's motif label:
##   B.list[[r]]   -- 3 x 3 subblock matrix for block r
##   rho.list[[r]] -- fractions of block r's vertices in each of its 3 subblocks
## Any label other than 1 or 2 falls through to the motif-3 parameters.
B.list <- rho.list <- list()
for (r in seq_len(R)) {
    switch(as.character(motif.vec[r]),
        "1" = {
            # 0.4 on the diagonal, 0.25 elsewhere
            B.list[[r]] <- matrix(c(0.4,  0.25, 0.25,
                                    0.25, 0.4,  0.25,
                                    0.25, 0.25, 0.4), nrow = 3, ncol = 3)
            rho.list[[r]] <- c(0.25, 0.5, 0.25)
        },
        "2" = {
            # 0.25 on the diagonal except 0.8 at [2,2]; 0.2 elsewhere
            B.list[[r]] <- matrix(c(0.25, 0.2, 0.2,
                                    0.2,  0.8, 0.2,
                                    0.2,  0.2, 0.25), nrow = 3, ncol = 3)
            rho.list[[r]] <- c(0.3, 0.4, 0.3)
        },
        {
            # 0.3 on the diagonal except 0.7 at [3,3]; 0.25 elsewhere
            B.list[[r]] <- matrix(c(0.3,  0.25, 0.25,
                                    0.25, 0.3,  0.25,
                                    0.25, 0.25, 0.7), nrow = 3, ncol = 3)
            rho.list[[r]] <- c(0.4, 0.2, 0.4)
        }
    )
}
## Inspect the generated per-block parameters. Element r of each list belongs
## to top-level block r and was chosen by motif.vec[r]. The commented output
## below was recorded from a run of this script (set.seed(1234) above).
## rho.list[[r]]: fractions of block r's vertices in each of its 3 subblocks.
rho.list
# [[1]]
# [1] 0.3 0.4 0.3
# 
# [[2]]
# [1] 0.3 0.4 0.3
# 
# [[3]]
# [1] 0.4 0.2 0.4
# 
# [[4]]
# [1] 0.3 0.4 0.3
# 
# [[5]]
# [1] 0.25 0.50 0.25
# 
# [[6]]
# [1] 0.4 0.2 0.4
# 
# [[7]]
# [1] 0.25 0.50 0.25
# 
# [[8]]
# [1] 0.4 0.2 0.4
## B.list[[r]]: 3 x 3 subblock matrix for block r, passed (with rho.list)
## to sample_hierarchical_sbm() further down.
B.list
# [[1]]
#      [,1] [,2] [,3]
# [1,] 0.25  0.2 0.20
# [2,] 0.20  0.8 0.20
# [3,] 0.20  0.2 0.25
# 
# [[2]]
#      [,1] [,2] [,3]
# [1,] 0.25  0.2 0.20
# [2,] 0.20  0.8 0.20
# [3,] 0.20  0.2 0.25
# 
# [[3]]
#      [,1] [,2] [,3]
# [1,] 0.30 0.25 0.25
# [2,] 0.25 0.30 0.25
# [3,] 0.25 0.25 0.70
# 
# [[4]]
#      [,1] [,2] [,3]
# [1,] 0.25  0.2 0.20
# [2,] 0.20  0.8 0.20
# [3,] 0.20  0.2 0.25
# 
# [[5]]
#      [,1] [,2] [,3]
# [1,] 0.40 0.25 0.25
# [2,] 0.25 0.40 0.25
# [3,] 0.25 0.25 0.40
# 
# [[6]]
#      [,1] [,2] [,3]
# [1,] 0.30 0.25 0.25
# [2,] 0.25 0.30 0.25
# [3,] 0.25 0.25 0.70
# 
# [[7]]
#      [,1] [,2] [,3]
# [1,] 0.40 0.25 0.25
# [2,] 0.25 0.40 0.25
# [3,] 0.25 0.25 0.40
# 
# [[8]]
#      [,1] [,2] [,3]
# [1,] 0.30 0.25 0.25
# [2,] 0.25 0.30 0.25
# [3,] 0.25 0.25 0.70
## Subblock labels: block x's nR.vec[x] vertices are split among its 3
## subblocks in proportions rho.list[[x]]. Subblock ids are numbered
## consecutively across blocks, so block x owns ids 3*(x-1) + 1:3.
svec <- lapply(1:R, function(x) rho.list[[x]]*nR.vec[x])
# lapply (not sapply) always returns a list, so unlist() yields a plain
# vector. sapply would return a matrix if every block had equal counts, and
# unlist() leaves a matrix unchanged.
# round() protects against floating-point counts such as 89.99999, which
# rep() would silently truncate to 89.
slab <- unlist(lapply(seq_len(R), function(x) {
    rep(1:3, times = round(svec[[x]])) + (x-1)*3
}))
# Every vertex must receive exactly one subblock label.
stopifnot(length(slab) == sum(nR.vec))
    
## Sample the hierarchical SBM on sum(nR.vec) vertices: top-level blocks of
## sizes nR.vec, each with its own subblock proportions (rho.list) and
## subblock matrix (B.list); p is the igraph `p` argument.
g <- sample_hierarchical_sbm(sum(nR.vec), nR.vec, rho.list, B.list, p)
## One colour per motif, looked up for each block.
mycol <- rainbow(max(mlab))[motif.vec]
# plotmemb() comes from ccc_utils.r; presumably it draws the adjacency
# matrix g[] with borders around the vlab blocks (drawborder = TRUE).
# paste0 avoids the stray spaces paste() inserted ("A, R =  8 , m = 3").
plotmemb(g[], vlab,
         main = paste0("A, R = ", max(vlab), ", m = 3"),
         drawborder = TRUE, lwd = .01, lcol = mycol, lwdb = 2)
## Step 1
## Adjacency spectral embedding of g into dmax dimensions. The embedding
## dimension dhat is then chosen from the column norms of Xhat.
dmax <- 50
Xhat <- embed_adjacency_matrix(g, dmax, options = list(maxiter = 10000))$X
eval <- sqrt(colSums(Xhat^2))
# getElbows() comes from ccc_utils.r. In the recorded run it returned three
# candidate elbows (below), and the first one is used.
(dhat <- getElbows(eval, 3, plot = FALSE))
# [1]  8 14 16
dhat <- dhat[1]
#dhat <- 24

## Cluster the vertices in the first dhat embedding dimensions. The cluster
## labels membp estimate the top-level blocks, and Rhat is their number.
#sXhat <- Xhat[,1:dhat] / sqrt(rowSums(Xhat[,1:dhat]^2))
# drop = FALSE keeps sXhat a matrix even when dhat == 1.
sXhat <- Xhat[, 1:dhat, drop = FALSE]
Rmax <- 1.5*dhat
# 2:Rmax runs 2, 3, ..., floor(Rmax). When dhat == 1 (Rmax = 1.5) it is just 2.
krange <- 2:Rmax
# vlclustK() comes from ccc_utils.r; its $memb component holds the labels.
cl.out <- vlclustK(sXhat, krange)
membp <- cl.out$memb
(Rhat <- max(membp))
# [1] 8
## In the recorded run, the cluster sizes match the multiset of true block
## sizes in nR.vec.
(tablep <- table(membp))
# membp
#   1   2   3   4   5   6   7   8 
# 600 600 600 300 400 700 300 600
mycol2 <- rainbow(Rhat)
plotmemb(g[], membp, main = "", drawborder = TRUE,
         lwd = .01, lcol = mycol2, lwdb = 2)
## Step 2
## Re-embed each estimated block in 3 dimensions with reembed(), then build the
## between-block matrix S with computeS() (both from ccc_utils.r). The blocks
## are clustered with pamk() using S as a dissimilarity; numc is the number of
## block groups found.
sigma <- 0.5
X.list <- reembed(g[], 3, membp)
S <- computeS(X.list, sigma)
S[S < 0] <- 0
rownames(S) <- colnames(S) <- 1:Rhat
# pam() requires k < number of observations. With fewer than 3 blocks,
# 2:(length(X.list)-1) would run backwards and include invalid k values.
if (length(X.list) < 3) {
    stop("Step 2 needs at least 3 estimated blocks; got ", length(X.list),
         call. = FALSE)
}
graphs.cluster <- pamk(S, diss = TRUE, krange = 2:(length(X.list)-1))
Yhat <- graphs.cluster$pamobject$clustering
(numc <- graphs.cluster$nc)
# [1] 3
## Step 3
## For each group i of blocks (Yhat == i):
##   - stack the members' 3-d embeddings, aligning each new member to the
##     rows pooled so far;
##   - cluster the pooled rows with pamk();
##   - take the cluster proportions as rhohat and the medoid inner products
##     as Bhat.
Bhat.list <- list()
rhohat.list <- list()
for (i in seq_len(numc)) {
    idx.i <- which(Yhat == i)
    Xi <- NULL
    for (j in idx.i) {
        if (is.null(Xi)) {
            Xi <- X.list[[j]]
        } else {
            # find.transform() comes from ccc_utils.r; presumably it returns
            # a matrix aligning X.list[[j]] to Xi. It is named Tj so the
            # global TRUE shorthand `T` is not masked.
            Tj <- find.transform(X.list[[j]], Xi)
            Xi <- rbind(Xi, X.list[[j]] %*% Tj)
        }
    }
    Xi.pamk <- pamk(Xi)
    # Exact component names. The old `$pamobj$cluster` only worked through
    # partial matching of `pamobject` and `clustering`.
    Xi.pam <- Xi.pamk$pamobject
    rhohat.list[[i]] <- as.vector(table(Xi.pam$clustering))/nrow(Xi)
    # Bi[j1, j2] = <medoid j1, medoid j2>, i.e. medoids %*% t(medoids).
    # unname() keeps Bi free of row names carried over from the data.
    Bi <- unname(tcrossprod(Xi.pam$medoids))
    Bhat.list[[i]] <- Bi
}
## Estimated motif parameters from Step 3, one element per block group. The
## commented output was recorded from a run of this script. In that run,
## pamk() chose 2 clusters for the first two groups, so their estimates are
## 2 x 2 even though every true matrix in B.list is 3 x 3.
Bhat.list
# [[1]]
#           [,1]      [,2]
# [1,] 0.2743059 0.2492344
# [2,] 0.2492344 0.7204546
# 
# [[2]]
#           [,1]      [,2]
# [1,] 0.2191020 0.2028399
# [2,] 0.2028399 0.7951436
# 
# [[3]]
#           [,1]      [,2]      [,3]
# [1,] 0.4050112 0.2684659 0.2626907
# [2,] 0.2684659 0.3924297 0.2535544
# [3,] 0.2626907 0.2535544 0.4080792
## Estimated subblock proportions (cluster sizes / pooled rows) per group.
rhohat.list
# [[1]]
# [1] 0.6 0.4
# 
# [[2]]
# [1] 0.6 0.4
# 
# [[3]]
# [1] 0.251 0.258 0.491
## Heatmap of S. Rows and columns are ordered by a ward.D2 hierarchical
## clustering of S (used as a distance, as in Step 2), with branches coloured
## into numc groups.
# library() fails loudly if a package is missing. require() only returns
# FALSE, which would surface later as an obscure "could not find function".
suppressMessages(library(gplots))
suppressMessages(library(dendextend))
Dshat <- as.dist(S)
# rhc is already a dendrogram, so the pipe below needs no as.dendrogram().
rhc <- as.dendrogram(hclust(Dshat, method = "ward.D2"))
hlwd <- 2
Rowv <- Colv <- rhc %>%
    set("branches_k_color", k = numc) %>%
    set("branches_lwd", hlwd) %>%
    rotate_DendSer(ser_weight = Dshat)
heatmap(S, Rowv = Rowv, Colv = Colv, revC = TRUE,
        col = rev(heat.colors(255)), scale = "none")