1
#' Linear SVM classifier for texts
#'
#' Fit a fast linear SVM classifier for texts, using the
#' \pkg{LiblineaR} package.
#' @param x the \link{dfm} on which the model will be fit.  Does not need to
#'   contain only the training documents.
#' @param y vector of training labels, one for each document in \code{x}.
#'   Documents whose label is \code{NA} are excluded from training, but they
#'   are kept in the fitted object, so \code{predict()} without \code{newdata}
#'   also returns predictions for them.
#' @param weight weights for different classes for imbalanced training sets,
#'   passed to \code{wi} in \code{\link[LiblineaR]{LiblineaR}}. \code{"uniform"}
#'   uses the LiblineaR default; \code{"docfreq"} weights each class by its
#'   proportion of the training documents; and \code{"termfreq"} weights each
#'   class by its proportion of the total term count of the training documents.
#' @param ... additional arguments passed to \code{\link[LiblineaR]{LiblineaR}},
#'   for instance \code{type} (the solver) or \code{cost}.
#' @details The solver is chosen through the \code{type} argument of
#'   \code{\link[LiblineaR]{LiblineaR}}, passed via \code{...}. Class
#'   probabilities (\code{predict(..., type = "probability")}) are only
#'   available for the solvers for which LiblineaR can compute them; see its
#'   documentation.
#' @return an object of class \code{textmodel_svm}: a list containing the
#'   input \code{x} and \code{y}, the fitted feature \code{weights}, the
#'   \code{algorithm}, \code{type}, \code{classnames} and \code{bias} reported
#'   by LiblineaR, the fitted LiblineaR model (\code{svmlinfitted}), and the
#'   \code{call}.
#' @references
#' R. E. Fan, K. W. Chang, C. J. Hsieh, X. R. Wang, and C. J. Lin. (2008)
#' LIBLINEAR: A Library for Large Linear Classification.
#' \emph{Journal of Machine Learning Research} 9: 1871-1874.
#' \url{https://www.csie.ntu.edu.tw/~cjlin/liblinear}.
#' @seealso \code{\link[LiblineaR]{LiblineaR}}
#' @examples
#' # use party leaders for govt and opposition classes
#' docvars(data_corpus_irishbudget2010, "govtopp") <-
#'     c(rep(NA, 4), "Govt", "Opp", NA, "Opp", NA, NA, NA, NA, NA, NA)
#' dfmat <- dfm(data_corpus_irishbudget2010)
#' tmod <- textmodel_svm(dfmat, y = docvars(dfmat, "govtopp"))
#' predict(tmod)
#' predict(tmod, type = "probability")
#'
#' # multiclass problem - all party leaders
#' tmod2 <- textmodel_svm(dfmat,
#'     y = c(rep(NA, 3), "SF", "FF", "FG", NA, "LAB", NA, NA, "Green", rep(NA, 3)))
#' predict(tmod2)
#' predict(tmod2, type = "probability")
#' @export
textmodel_svm <- function(x, y, weight = c("uniform", "docfreq", "termfreq"), ...) {
    UseMethod("textmodel_svm")
}
40

41
#' @export
textmodel_svm.default <- function(x, y, weight = c("uniform", "docfreq", "termfreq"), ...) {
    # no method exists for this input class: fail with quanteda's standard
    # "class not supported" message
    msg <- quanteda:::friendly_class_undefined_message(class(x), "textmodel_svm")
    stop(msg)
}
45

46
#' @export
#' @importFrom LiblineaR LiblineaR
#' @importFrom SparseM as.matrix.csr
textmodel_svm.dfm <- function(x, y, weight = c("uniform", "docfreq", "termfreq"), ...) {
    x <- as.dfm(x)
    if (!sum(x)) stop(quanteda:::message_error("dfm_empty"))
    call <- match.call()
    weight <- match.arg(weight)

    # y must be aligned with the documents of x; a shorter y would otherwise
    # be silently recycled when used as a logical row index below
    if (length(y) != nrow(x)) {
        stop("the length of y (", length(y),
             ") must equal the number of documents in x (", nrow(x), ")",
             call. = FALSE)
    }
    is_train <- !is.na(y)
    if (!any(is_train)) {
        stop("y must contain at least one non-missing training label", call. = FALSE)
    }

    # keep only documents with a training label, and drop features that do
    # not occur in them (warnings from dfm_trim are suppressed)
    x_train <- suppressWarnings(
        dfm_trim(x[is_train, ], min_termfreq = .0000000001, termfreq_type = "prop")
    )
    y_train <- y[is_train]

    # remove zero-variance features; checked on the sparse representation so
    # the training dfm is never converted to a dense matrix
    constant_features <- svm_constant_features(x_train)
    if (length(constant_features)) x_train <- x_train[, -constant_features]

    # class weights (wi) passed to LiblineaR, depending on weight value
    wi <- switch(weight,
        uniform = NULL,
        docfreq = prop.table(table(y_train)),
        termfreq = {
            class_tf <- rowSums(dfm_group(x_train, y_train))
            class_tf / sum(class_tf)
        }
    )

    svmlinfitted <- LiblineaR::LiblineaR(as.matrix.csr.dfm(x_train),
                                         target = y_train, wi = wi, ...)
    # name the weight columns by feature; any trailing column beyond the
    # features (the bias term) keeps the name LiblineaR gave it
    colnames(svmlinfitted$W)[seq_along(featnames(x_train))] <- featnames(x_train)
    result <- list(
        x = x, y = y,
        weights = svmlinfitted$W,
        algorithm = svmlinfitted$TypeDetail,
        type = svmlinfitted$Type,
        classnames = svmlinfitted$ClassNames,
        bias = svmlinfitted$Bias,
        svmlinfitted = svmlinfitted,
        call = call
    )
    class(result) <- c("textmodel_svm", "textmodel", "list")
    result
}

# Column indices of a column-compressed sparse matrix (such as a dfm) whose
# values are all identical, i.e. the zero-variance features, computed without
# densifying the matrix. With fewer than two rows the variance is undefined
# (stats::var() gives NA), so no column is flagged.
svm_constant_features <- function(m) {
    n_row <- m@Dim[1L]
    n_col <- m@Dim[2L]
    if (n_row < 2L) return(integer(0))
    nnz <- diff(m@p)
    col_index <- factor(rep.int(seq_len(n_col), nnz), levels = seq_len(n_col))
    is_constant <- vapply(split(m@x, col_index), function(v) {
        # unstored entries of a sparse column are zeros, so add one 0
        if (length(v) < n_row) v <- c(v, 0)
        all(v == v[1L])
    }, logical(1))
    which(is_constant, useNames = FALSE)
}
90

91
# helper methods ----------------
92

93
#' Prediction from a fitted textmodel_svm object
#'
#' \code{predict.textmodel_svm()} returns class predictions, or class
#' probabilities, for documents from a fitted linear SVM model.
#' @param object a fitted linear SVM textmodel
#' @param newdata dfm on which prediction should be made; when \code{NULL},
#'   the dfm stored in the fitted model (\code{object$x}) is used
#' @param type the type of predicted values to be returned; see Value
#' @param force make newdata's feature set conformant to the model terms
#' @param ... not used
#' @return \code{predict.textmodel_svm} returns either a vector of class
#'   predictions for each row of \code{newdata} (when \code{type = "class"}), or
#'   a document-by-class matrix of class probabilities (when \code{type =
#'   "probability"}).
#' @seealso \code{\link{textmodel_svm}}
#' @keywords textmodel internal
#' @importFrom SparseM as.matrix.csr
#' @export
predict.textmodel_svm <- function(object, newdata = NULL,
                                  type = c("class", "probability"),
                                  force = TRUE, ...) {
    quanteda:::unused_dots(...)
    type <- match.arg(type)

    predict_training <- is.null(newdata)
    data <- as.dfm(if (predict_training) object$x else newdata)

    # model feature names, without the trailing bias column that the weight
    # matrix carries when the model was fit with bias > 0
    model_featnames <- colnames(object$weights)
    if (object$bias > 0) model_featnames <- head(model_featnames, -1L)

    # align the features of data with the model; warnings are silenced only
    # when predicting the model's own documents
    conform <- function(d) quanteda:::force_conformance(d, model_featnames, force)
    data <- if (predict_training) suppressWarnings(conform(data)) else conform(data)

    fitted <- predict(object$svmlinfitted,
                      newx = as.matrix.csr.dfm(data),
                      proba = identical(type, "probability"))

    if (type == "probability") {
        out <- fitted$probabilities
        rownames(out) <- docnames(data)
    } else {
        out <- fitted$predictions
        names(out) <- docnames(data)
    }
    out
}
146

147
#' @export
#' @method print textmodel_svm
print.textmodel_svm <- function(x, ...) {
    # x$weights has one column per feature plus, when bias > 0, a trailing
    # bias column; length() would count every cell (classes x columns)
    n_features <- ncol(x$weights) - isTRUE(x$bias > 0)
    cat("\nCall:\n")
    print(x$call)
    cat("\n",
        format(length(na.omit(x$y)), big.mark = ","), " training documents; ",
        format(n_features, big.mark = ","), " fitted features",
        ".\n",
        "Method: ", x$algorithm, "\n",
        sep = "")
    invisible(x)
}
159

160
#' summary method for textmodel_svm objects
#' @param object output from \code{\link{textmodel_svm}}
#' @param n how many feature coefficients to print before truncating
#' @param ... additional arguments not used
#' @keywords textmodel internal
#' @method summary textmodel_svm
#' @export
summary.textmodel_svm <- function(object, n = 30, ...) {
    wts <- coef(object)
    # the weights have one column per feature, so truncate the columns;
    # head() on the matrix itself would only truncate its rows (classes)
    if (!is.null(dim(wts))) {
        wts <- wts[, head(seq_len(ncol(wts)), n), drop = FALSE]
    } else {
        wts <- head(wts, n)
    }
    result <- list(
        "call" = object$call,
        "estimated.feature.scores" = as.coefficients_textmodel(wts)
    )
    as.summary.textmodel(result)
}
174

175
#' @noRd
#' @method coef textmodel_svm
#' @importFrom stats coef
#' @export
coef.textmodel_svm <- function(object, ...) {
    # the fitted feature weight matrix stored by textmodel_svm()
    object[["weights"]]
}
182

183
#' @noRd
#' @method coefficients textmodel_svm
#' @importFrom stats coefficients
#' @export
coefficients.textmodel_svm <- function(object, ...) {
    # alias of coef(): delegate explicitly rather than calling UseMethod()
    # from inside a method
    coef(object, ...)
}
190

191
#' @export
#' @method print predict.textmodel_svm
print.predict.textmodel_svm <- function(x, ...) {
    # print the bare values; unclass() stops print() from dispatching back
    # to this method
    print(unclass(x), ...)
    # print methods return their input invisibly, with its class intact
    invisible(x)
}
196

197
#' convert a dfm into a matrix.csr from SparseM package
#'
#' Utility that converts a \link{dfm} into a \link[SparseM]{matrix.csr} from the
#' \pkg{SparseM} package, building a SparseM compressed-column matrix first.
#' @param x input \link{dfm}
#' @importFrom SparseM as.matrix.csr
#' @importFrom methods new
#' @method as.matrix.csr dfm
#' @keywords internal
as.matrix.csr.dfm <- function(x) {
    # the dfm's index slots x@i and x@p are 0-based, hence the + 1L for
    # SparseM's 1-based ja / ia
    csc <- new("matrix.csc",
               ra = x@x,
               ja = x@i + 1L,
               ia = x@p + 1L,
               dimension = x@Dim)
    as.matrix.csr(csc)
}

Read our documentation on viewing source code .

Loading