class: center, middle, inverse, title-slide

# Part 5: Classification

### Sam Tyner
### TBD

---
class: primary
# Outline

1. What is classification?
2. Motivating Example
3. Features
4. A familiar example (logistic regression)
5. Trees `rpart`
6. Forests `randomForest`
7. K Nearest neighbors `caret::knn3`

---
class: inverse
# What is classification?

---
class: primary
# Definition

Classification is

- for determining the category a new observation belongs to
- based on a **training** set of data with observations belonging to different categories
- used on a **test** set of data where the categories are unknown

The *response* variable is the category of the observation

The *explanatory* variables are **features** of the observations

---
class: primary
# What we want

We want to *learn* about the categories from the training set so we can *classify* the (future) observations from the test set.

---
class: inverse
# Motivating Example

---
class: primary
# Handwritten digits

<img src="classification_files/figure-html/dat-1.png" style="display: block; margin: auto;" />

---
class: primary
# Classification

It's easy for humans to identify the digit written in each of those images. But...

- we're not as fast as computers
- we get fatigued
- we get bored
- we make mistakes

It's a lot harder for computers to identify a written digit. But...

- they're freaky fast
- they (almost) never get tired
- they never get bored of mundane tasks
- they don't make mistakes

---
class: primary
# Goal

We are going to use statistics to classify the handwritten digits

We need:

- lots of training data
- feature sets

Download the training data [here](https://github.com/CSAFE-ISU/REU-materials-2018/blob/master/slides/statistics/dat/smallish-digit-train.csv) or read it in using this url:

```
url <- "https://raw.githubusercontent.com/CSAFE-ISU/REU-materials-2018/master/slides/statistics/dat/smallish-digit-train.csv"
```

---
class: inverse
# Features

---
class: primary
# Definition

**Features** are a subset of variables or combinations of variables that are used to build a model

Feature extraction and feature selection are important because:

- they remove redundant information from the data
- they help avoid overfitting
- they can make the model easier to interpret

We can use algorithms to extract/select features (later) or we can create them on our own

---
class: primary
# Examples

For handwritten digits (one is sketched on the next slide):

- means of the grayscale values in rows and/or columns
- number of rows and/or columns that are all zeroes

For Dr. Hofmann's bullet data:

- cross-correlation factor
- continuous matching striae
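---
class: primary
# Feature sketch

One way the first digit feature could be computed: the mean grayscale value in each pixel row. This is a rough sketch, not part of the original deck; it assumes `train` is the linked csv read into R, with the digit in `label` followed by the 784 pixel columns.

```r
# mean grayscale per pixel row for one digit image (28 x 28 pixels, filled by row)
row_means <- function(arow){
  pixels <- matrix(as.numeric(unlist(arow)), nrow = 28, ncol = 28, byrow = TRUE)
  rowMeans(pixels)  # 28 values, one per row of the image
}

# feature vector for the first training image (drop the label column first)
row_means(train[1, -1])
```

Rows that are entirely white background come out as 0, which is exactly where the "number of all 0s" feature on the next slide starts.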
---
class: primary
# Number of all 0s

Let's first build a simple feature: the number of rows in the pixel image that are all 0 (all white)

First, we have to get the data into the right form:

```r
library(tidyverse)
train <- read_csv("dat/smallish-digit-train.csv")
# function to convert each row to a matrix
to_raster <- function(arow){
  pixels <- as.numeric(arow)
  raster <- data.frame(matrix(data = pixels, nrow = 28, ncol = 28, byrow = T),
                       stringsAsFactors = F) %>% 
    mutate(row = 1:n()) %>% 
    gather(key = col, value = color, -row) %>% 
    mutate(col = parse_number(as.character(col)),
           color = 255 - as.numeric(as.character(color)),
           color2 = rgb(color, color, color, maxColorValue = 255))
  return(raster)
}
```

---
class: primary
# Number of all 0s

```r
train_raster <- train %>% 
  group_by(label) %>% 
  mutate(rep = row_number()) %>% 
  ungroup() %>% 
  nest(-c(label, rep)) %>% 
  mutate(rasters = map(data, to_raster)) %>% 
  select(-data) %>% 
  unnest()
train_feature1 <- train_raster %>% 
  group_by(label, rep, row) %>% 
  summarise(zeroes = sum(color == 255)) %>% 
  ungroup() %>% 
  group_by(label, rep) %>% 
  summarise(n_0 = sum(zeroes == 28))
```

---
class: inverse
# Logistic Regression (remix)

---
class: primary
# Model

.small[
Need 2 categories, so let's pick a number: 5

"Success" = 5, "Failure" = not 5
]

```r
train_feature1 <- train_feature1 %>% 
  mutate(label2 = as.numeric(label == 5))
lr1 <- glm(label2 ~ n_0, data = train_feature1, family = "binomial")
coef(lr1)
```

```
## (Intercept)         n_0 
##  -7.4333442   0.6178515
```

```r
exp(coef(lr1))
```

```
##  (Intercept)          n_0 
## 0.0005912071 1.8549383858
```

.small[For each additional row of all 0s, the odds that the number is a 5 are multiplied by 1.85. (The odds roughly double with every additional all-zero row.)]
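---
class: primary
# Model, by hand

A quick sanity check (a sketch, not in the original deck) using the coefficients above: the fitted model says the probability of a 5 is `plogis(-7.43 + 0.618 * n_0)`.

```r
# predicted probability of "5" at a few values of n_0, plugging in the fitted coefficients
plogis(-7.4333442 + 0.6178515 * 0)    # no all-zero rows:   ~0.0006
plogis(-7.4333442 + 0.6178515 * 12)   # 12 all-zero rows:   ~0.495, about even odds
```

The second value is where the 0.495 column in the prediction table on the next slide comes from.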
---
class: primary
# Prediction

```r
test <- read_csv("dat/small-digit-train.csv")
test_raster <- test %>% 
  group_by(label) %>% 
  mutate(rep = row_number()) %>% 
  ungroup() %>% 
  nest(-c(label, rep)) %>% 
  mutate(rasters = map(data, to_raster)) %>% 
  select(-data) %>% 
  unnest()
test_feature1 <- test_raster %>% 
  group_by(label, rep, row) %>% 
  summarise(zeroes = sum(color == 255)) %>% 
  ungroup() %>% 
  group_by(label, rep) %>% 
  summarise(n_0 = sum(zeroes == 28))
test_feature1 <- test_feature1 %>% 
  ungroup() %>% 
  mutate(label2 = as.numeric(label == 5),
         pred = predict(lr1, newdata = test_feature1, type = "response"))
table(test_feature1$label2, round(test_feature1$pred, 3))
```

```
##    
##     0.077 0.133 0.222 0.495
##   0    85     1     3     1
##   1     9     0     1     0
```

---
class: primary
# Assessment

Not a very good model:

- few "successes" relative to the number of failures (only 10% are successes)
- model is only based on one aggregate feature - need more!
- want more classes, not just "success" and "failure"

Other options:

- trees
- forests
- k-nearest neighbors

---
class: inverse
# Decision trees

---
class: primary
# Decision trees

Classification with a sequence of binary decisions. Example: who survived the sinking of the Titanic? ([Wikipedia](https://commons.wikimedia.org/wiki/File:CART_tree_titanic_survivors.png))

<div class="figure" style="text-align: center">
<img src="classification_files/figure-html/titanic-1.png" alt="A tree showing survival of passengers on the Titanic (&quot;sibsp&quot; is the number of spouses or siblings aboard). The figures under the leaves show the probability of survival and the percentage of observations in the leaf." />
<p class="caption">A tree showing survival of passengers on the Titanic ("sibsp" is the number of spouses or siblings aboard). The figures under the leaves show the probability of survival and the percentage of observations in the leaf.</p>
</div>

---
class: primary
# Digits

```r
library(rpart)
library(rpart.plot)  # for prp()
set.seed(123456)
# fit tree
tree1 <- rpart(as.factor(label) ~ ., data = train, cp = .02)
# plot tree
prp(tree1, under = T, extra = 104, tweak = 1.5)
```

---
class: primary
# Digits

<img src="classification_files/figure-html/tree2-1.png" width="90%" style="display: block; margin: auto;" />

---
class: primary
# Classifier regions

First 2 splits only: how does the model think?

<img src="classification_files/figure-html/classregion-1.png" style="display: block; margin: auto;" />

---
class: primary
# Confusion matrix

A **confusion matrix** shows you where the model gets it wrong. This can be done for both the training and testing data.

Training:

```
##    pred_tree1
##        0    1    2    3    4    5    6    7    8    9
##   0 0.75 0.01 0.00 0.01 0.00 0.14 0.00 0.03 0.04 0.02
##   1 0.01 0.89 0.04 0.02 0.02 0.00 0.00 0.01 0.00 0.01
##   2 0.07 0.11 0.54 0.02 0.03 0.00 0.06 0.07 0.08 0.02
##   3 0.03 0.08 0.00 0.76 0.02 0.01 0.00 0.02 0.03 0.05
##   4 0.00 0.02 0.00 0.04 0.77 0.01 0.04 0.00 0.01 0.11
##   5 0.09 0.14 0.00 0.31 0.01 0.28 0.01 0.06 0.07 0.03
##   6 0.11 0.11 0.09 0.00 0.11 0.04 0.38 0.02 0.09 0.05
##   7 0.00 0.08 0.01 0.00 0.06 0.02 0.00 0.71 0.00 0.12
##   8 0.01 0.12 0.06 0.06 0.00 0.00 0.02 0.03 0.59 0.11
##   9 0.00 0.04 0.03 0.01 0.01 0.00 0.00 0.10 0.00 0.81
```

Which class is it best at predicting? Which is it worst at predicting?

---
class: primary
# Confusion matrix

Testing:

```
##    pred_test
##       0   1   2   3   4   5   6   7   8   9
##   0 0.8 0.0 0.0 0.0 0.0 0.1 0.0 0.1 0.0 0.0
##   1 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
##   2 0.1 0.0 0.7 0.0 0.0 0.0 0.0 0.1 0.1 0.0
##   3 0.0 0.3 0.0 0.7 0.0 0.0 0.0 0.0 0.0 0.0
##   4 0.2 0.0 0.0 0.0 0.7 0.0 0.0 0.0 0.0 0.1
##   5 0.3 0.1 0.0 0.4 0.0 0.1 0.0 0.0 0.1 0.0
##   6 0.1 0.0 0.0 0.0 0.0 0.2 0.3 0.0 0.4 0.0
##   7 0.0 0.1 0.0 0.0 0.1 0.0 0.0 0.4 0.0 0.4
##   8 0.0 0.4 0.1 0.0 0.1 0.0 0.0 0.0 0.4 0.0
##   9 0.0 0.2 0.0 0.0 0.2 0.0 0.1 0.0 0.0 0.5
```

Which class is it best at predicting? Which is it worst at predicting?

---
class: primary
# Key Facts

- randomness: if you don't set a "random seed", you'll end up with different results every time
- trees with too many branches overfit the data

---
class: inverse
# Forests

---
class: primary
# Random Forests

Forests are made up of trees:

- built on different (random) subsets of data
- use different variables
- trees are typically small
- each tree does something well, but does other things poorly
- the trees work well together, better than one large tree (like a musical ensemble)
- classes are assigned based on the most common group from the many trees (e.g. if an observation is classified as 5 in 6/10 trees and as 6 in 4/10 trees, the forest will classify it as 5; see the sketch on the next slide)
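---
class: primary
# Majority rules

A toy illustration of the voting step, not in the original deck, using the 6/10 vs. 4/10 example above.

```r
# hypothetical votes from 10 trees for one observation
votes <- c(5, 5, 5, 5, 5, 5, 6, 6, 6, 6)
table(votes)                    # 5: six votes, 6: four votes
names(which.max(table(votes)))  # the forest's prediction: "5"
```

`randomForest()` does this tallying for us across all of its trees.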
---
class: primary
# Digits forest

```r
library(randomForest)
digit_forest <- randomForest(as.factor(label) ~ ., data = train, ntree = 1000)
digit_forest$confusion
```

```
##    0  1  2  3  4  5  6  7  8  9 class.error
## 0 95  0  0  0  0  1  2  0  1  1        0.05
## 1  0 98  1  0  0  0  1  0  0  0        0.02
## 2  0  2 90  1  1  0  1  3  1  1        0.10
## 3  1  1  3 80  1  8  0  1  3  2        0.20
## 4  0  0  0  0 91  0  1  0  0  8        0.09
## 5  2  2  1  6  0 83  4  0  1  1        0.17
## 6  1  0  1  0  1  5 91  1  0  0        0.09
## 7  0  1  0  0  0  1  0 94  1  3        0.06
## 8  0  3  3  3  2  3  2  0 82  2        0.18
## 9  0  0  0  0  2  1  0  5  0 92        0.08
```

---
class: primary
# Decisions

<img src="classification_files/figure-html/unnamed-chunk-1-1.png" style="display: block; margin: auto;" />

---
class: primary
# Test data

```
##    pred_test
##      0  1  2  3  4  5  6  7  8  9
##   0 10  0  0  0  0  0  0  0  0  0
##   1  0 10  0  0  0  0  0  0  0  0
##   2  0  0  9  0  0  0  1  0  0  0
##   3  0  0  0  9  0  1  0  0  0  0
##   4  0  0  0  0  9  0  1  0  0  0
##   5  0  0  0  0  0 10  0  0  0  0
##   6  0  0  0  0  0  0 10  0  0  0
##   7  0  0  0  0  1  0  0  9  0  0
##   8  0  0  1  0  0  0  0  0  9  0
##   9  0  1  0  0  2  1  0  0  0  6
```

---
class: primary
# Var. importance

Which pixel value(s) has/have the most influence over the model?

```r
varImpPlot(digit_forest)
```

<img src="classification_files/figure-html/import-1.png" style="display: block; margin: auto;" />

---
class: secondary

```r
ggplot(data = train, aes(x = pixel350, y = pixel378)) + 
  geom_text(aes(label = label)) + 
  theme_csafe()
```

<img src="classification_files/figure-html/import2-1.png" style="display: block; margin: auto;" />

---
class: primary
# Key Facts

- forests are made up of decision trees
- majority rules
- trees are built using different random subsets of the data
- trees work better in groups

---
class: inverse
# K Nearest neighbors

---
class: primary
# Everyone's doing it!

![](img/bridgejump.png)

---
class: primary
# KNN Idea

Classify points into groups based on what groups their `\(k\)` nearest neighbors are in (see the sketch on the next slide)

- `\(k\)` is the number of neighbors to look at
- larger `\(k\)` smooths over local structure (simpler boundaries)
- smaller `\(k\)` follows the training data more closely (flexible, but can overfit)
- choose `\(k\)` to give the most accurate model without overfitting
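---
class: primary
# KNN by hand

A rough sketch of what "nearest" means here, not in the original deck. It assumes `train` and `test` are the digit data read in earlier, with the label in the first column and the 784 pixel columns after it.

```r
x <- as.numeric(unlist(test[1, -1]))   # one test image as a vector of pixels
# Euclidean distance from x to every training image
d <- apply(train[, -1], 1, function(t) sqrt(sum((t - x)^2)))
nn <- order(d)[1:4]                    # the 4 closest training images
table(train$label[nn])                 # their labels; the majority vote is the prediction
```

`knn3()` in `caret` is, roughly, doing this for every test observation at once.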
---
class: primary
# Pick k

Use brute force: do many `\(k\)` and pick the `\(k\)` value that results in lowest error rate on the test set.

```r
library(caret)
knn_errors <- NULL
for (i in 1:30){
  digits_KNN <- knn3(as.factor(label) ~ ., data = train, k = i)
  pred_KNN <- predict(digits_KNN, newdata = test, type = "class")
  error_rate <- mean(as.factor(test$label) != pred_KNN)
  score <- c(k = i, error = error_rate)
  knn_errors <- rbind(knn_errors, score)
}
ggplot(data = data.frame(knn_errors), aes(x = k, y = error)) + 
  geom_line() + geom_point() + theme_csafe()
```

<img src="classification_files/figure-html/knn-1.png" style="display: block; margin: auto;" />

---
class: primary
# K = 4

Fit the model and look at class probabilities:

```r
digits_k4 <- knn3(as.factor(label) ~ ., data = train, k = 4)
pred_k4 <- predict(digits_k4, newdata = test)
data.frame(pred_k4) %>% 
  mutate(id = row_number()) %>% 
  gather(key = class, value = prob, -id) %>% 
  mutate(class = parse_number(class)) %>% 
  ggplot() + 
  geom_tile(aes(x = class, y = id, fill = prob)) + 
  scale_fill_continuous(low = "white", high = "black") + 
  theme_csafe() + 
  theme(aspect.ratio = 1)
```

<img src="classification_files/figure-html/knn4-1.png" style="display: block; margin: auto;" />

---
class: primary
# Confusion

```r
pred_k4.2 <- predict(digits_k4, newdata = test, type = "class")
table(test$label, pred_k4.2)
```

```
##    pred_k4.2
##      0  1  2  3  4  5  6  7  8  9
##   0 10  0  0  0  0  0  0  0  0  0
##   1  0 10  0  0  0  0  0  0  0  0
##   2  0  1  7  0  0  0  0  2  0  0
##   3  0  0  0  8  0  1  0  0  0  1
##   4  0  0  0  0  8  0  1  0  0  1
##   5  0  0  0  0  0 10  0  0  0  0
##   6  0  0  0  0  0  0 10  0  0  0
##   7  0  0  0  0  0  0  0 10  0  0
##   8  0  1  0  0  0  1  0  0  7  1
##   9  0  1  0  0  0  1  0  0  0  8
```
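---
class: primary
# Accuracy

One way to boil a confusion matrix down to a single number (a quick sketch, not in the original deck): the proportion of test digits on the diagonal.

```r
cm <- table(test$label, pred_k4.2)  # the confusion matrix from the previous slide
sum(diag(cm)) / sum(cm)             # overall accuracy: 88/100 = 0.88 here
```

The same calculation works for the tree and forest confusion matrices, which makes the three classifiers easy to compare.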