#' Logistic regression classifier for texts
#'
#' Fits a fast penalized maximum likelihood estimator to predict discrete
#' categories from sparse [dfm][quanteda::dfm] objects. Using the \pkg{glmnet}
#' package, the function computes the regularization path for the lasso or
#' elasticnet penalty at a grid of values for the regularization parameter
#' lambda.  This is done automatically by testing on several folds of the data
#' at estimation time.
#' @param x the \link{dfm} on which the model will be fit.  Does not need to
#'   contain only the training documents.
#' @param y vector of training labels associated with each document in `x`.
#'   (These will be converted to factors if not already factors.)  Documents
#'   with an `NA` label are excluded from fitting, so they may be held out
#'   for later prediction.
#' @param ... additional arguments passed to [`cv.glmnet()`][glmnet::cv.glmnet()]
#' @seealso [`cv.glmnet()`][glmnet::cv.glmnet()], [predict.textmodel_lr()],
#'   [coef.textmodel_lr()]
#' @references
#' Friedman, J., Hastie, T., & Tibshirani, R. (2010). [Regularization Paths for
#' Generalized Linear Models via Coordinate
#' Descent](http://dx.doi.org/10.18637/jss.v033.i01). _Journal of Statistical
#' Software_ 33(1), 1-22.
#' @examples
#' ## Example from 13.1 of _An Introduction to Information Retrieval_
#' corp <- quanteda::corpus(c(d1 = "Chinese Beijing Chinese",
#'                            d2 = "Chinese Chinese Shanghai",
#'                            d3 = "Chinese Macao",
#'                            d4 = "Tokyo Japan Chinese",
#'                            d5 = "London England Chinese",
#'                            d6 = "Chinese Chinese Chinese Tokyo Japan"),
#'                          docvars = data.frame(train = factor(c("Y", "Y", "Y",
#'                                                                "N", "N", NA))))
#' dfmat <- quanteda::dfm(corp, tolower = FALSE)
#'
#' ## simulate bigger sample as classification on small samples is problematic
#' set.seed(1)
#' dfmat <- quanteda::dfm_sample(dfmat, 50, replace = TRUE)
#'
#' ## train model
#' (tmod1 <- textmodel_lr(dfmat, quanteda::docvars(dfmat, "train")))
#' summary(tmod1)
#' coef(tmod1)
#'
#' ## predict probability and classes
#' predict(tmod1, type = "prob")
#' predict(tmod1)
#' @export
textmodel_lr <- function(x, y, ...) {
    UseMethod("textmodel_lr")
}

#' @export
textmodel_lr.default <- function(x, y, ...) {
    # Fallback for unsupported input classes: fail with an informative message.
    # NOTE(review): relies on an unexported quanteda internal — verify it still
    # exists when bumping the quanteda dependency.
    msg <- quanteda:::friendly_class_undefined_message(class(x), "textmodel_lr")
    stop(msg)
}

#' @export
#' @importFrom glmnet cv.glmnet
textmodel_lr.dfm <- function(x, y, ...) {
    # Fit a penalized (cross-validated) logistic regression on a dfm.
    x <- as.dfm(x)
    if (!sum(x)) stop(quanteda:::message_error("dfm_empty"))
    call <- match.call()

    # Documents with an NA label are excluded from fitting; the tiny
    # proportional min_termfreq drops features absent from the training rows.
    labelled <- !is.na(y)
    y_train <- y[labelled]
    x_train <- suppressWarnings(
        dfm_trim(x[labelled, ], min_termfreq = .0000000001,
                 termfreq_type = "prop")
    )

    # Number of outcome classes decides the glmnet family.
    n_class <- if (is.factor(y_train)) {
        length(levels(y_train))
    } else {
        length(unique(y_train))
    }

    if (n_class > 2) {
        family <- "multinomial"
    } else if (n_class > 1) {
        family <- "binomial"
    } else {
        stop("y must at least have two different labels.")
    }

    # Cross-validated fit over the lambda regularization path.
    lrfitted <- glmnet::cv.glmnet(
        x = x_train,
        y = y_train,
        family = family,
        maxit = 10000,
        ...
    )

    out <- list(
        x = x, y = y,
        algorithm = paste(family, "logistic regression"),
        type = family,
        classnames = lrfitted[["glmnet.fit"]][["classnames"]],
        lrfitted = lrfitted,
        call = call
    )
    class(out) <- c("textmodel_lr", "textmodel", "list")
    out
}

# helper methods ----------------

#' Prediction from a fitted textmodel_lr object
#'
#' `predict.textmodel_lr()` implements class predictions from a fitted
#' logistic regression model.
#' @param object a fitted logistic regression textmodel
#' @param newdata dfm on which prediction should be made
#' @param type the type of predicted values to be returned; see Value
#' @param force make newdata's feature set conformant to the model terms
#' @param ... not used
#' @return `predict.textmodel_lr()` returns either a vector of class
#'   predictions for each row of `newdata` (when `type = "class"`), or
#'   a document-by-class matrix of class probabilities (when
#'   `type = "probability"`).
#' @seealso [textmodel_lr()]
#' @keywords textmodel internal
#' @importFrom stats predict
#' @method predict textmodel_lr
#' @export
predict.textmodel_lr <- function(object, newdata = NULL,
                                 type = c("class", "probability"),
                                 force = TRUE, ...) {
    type <- match.arg(type)
    # glmnet calls probability output "response"
    if (type == "probability") {
        type <- "response"
    }

    # with no newdata, predict on the documents the model was fitted with
    on_training_data <- is.null(newdata)
    data <- if (on_training_data) as.dfm(object$x) else as.dfm(newdata)

    model_featnames <- colnames(object$x)
    data <- if (on_training_data) {
        # training features always conform; silence spurious warnings
        suppressWarnings(force_conformance(data, model_featnames, force))
    } else {
        force_conformance(data, model_featnames, force)
    }

    pred_y <- predict(
        object$lrfitted,
        newx = data,
        type = type,
        ...
    )

    if (type == "class") {
        pred_y <- as.factor(pred_y)
        names(pred_y) <- quanteda::docnames(data)
    } else if (type == "response") {
        if (ncol(pred_y) == 1) {
            # binomial fit returns only P(second class); add the complement
            pred_y <- cbind(pred_y[, 1], 1 - pred_y[, 1])
            colnames(pred_y) <- rev(object$classnames)
        } else {
            # multinomial: drop the third (lambda) array dimension
            pred_y <- pred_y[, , 1]
        }
    }
    pred_y
}

#' @export
#' @method print textmodel_lr
print.textmodel_lr <- function(x, ...) {
    # Print the call followed by a one-line summary of the fitted model.
    cat("\nCall:\n")
    print(x$call)
    n_docs <- format(quanteda::ndoc(x$x), big.mark = ",")
    n_feats <- format(quanteda::nfeat(x$x), big.mark = ",")
    cat("\n",
        n_docs, " training documents; ",
        n_feats, " fitted features",
        ".\n",
        "Method: ", x$algorithm, "\n",
        sep = "")
}

#' @rdname predict.textmodel_lr
#' @method coef textmodel_lr
#' @return `coef.textmodel_lr()` returns a (sparse) matrix of coefficients for
#'   each feature, computed at the value of the penalty parameter fitted in the
#'   model.  For binary outcomes, results are returned only for the class
#'   corresponding to the second level of the factor response; for multinomial
#'   outcomes, these are computed for each class.
#' @importFrom stats coef
#' @export
coef.textmodel_lr <- function(object, ...) {
    if (object$type == "multinomial") {
        # one coefficient vector per class: bind them into a single matrix
        out <- do.call(cbind, coef(object$lrfitted))
        colnames(out) <- object$classnames
    } else if (object$type == "binomial") {
        # single column, labelled with the modelled (second) class
        out <- coef(object$lrfitted)
        colnames(out) <- object$classnames[2]
    }
    out
}

#' @rdname predict.textmodel_lr
#' @method coefficients textmodel_lr
#' @importFrom stats coefficients
#' @export
coefficients.textmodel_lr <- function(object, ...) {
    # Alias for coef(): delegate directly to the coef generic rather than
    # calling UseMethod("coef") from inside a non-generic, which re-dispatches
    # under a different generic's name and is fragile/non-idiomatic.
    coef(object, ...)
}

#' summary method for textmodel_lr objects
#' @param object output from [textmodel_lr()]
#' @param n how many coefficients to print before truncating
#' @param ... additional arguments not used
#' @keywords textmodel internal
#' @method summary textmodel_lr
#' @export
summary.textmodel_lr <- function(object, n = 30, ...) {
    # Show the call, the cross-validated lambda values, and the top n
    # estimated feature coefficients (densified for printing).
    top_scores <- as.matrix(head(coef(object), n))
    out <- list(
        "call" = object$call,
        "lambda min" = object$lrfitted$lambda.min,
        "lambda 1se" = object$lrfitted$lambda.1se,
        "estimated.feature.scores" = top_scores
    )
    as.summary.textmodel(out)
}
