from typing import List, Mapping, Optional

import torch
import torch.nn.functional as F

Outputs = Mapping[str, List[torch.Tensor]]


def cross_entropy_with_probs(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Calculate cross-entropy loss when targets are probabilities (floats), not ints.

    PyTorch's F.cross_entropy() method requires integer labels; it does not accept
    probabilistic labels. We can, however, simulate such functionality with a for loop,
    calculating the loss contributed by each class and accumulating the results.
    Libraries such as Keras do not require this workaround, as methods like
    "categorical_crossentropy" accept float labels natively.

    Note that the method signature is intentionally very similar to F.cross_entropy()
    so that it can be used as a drop-in replacement when target labels are changed
    from a 1D tensor of ints to a 2D tensor of probabilities.

    Parameters
    ----------
    input
        A [num_points, num_classes] tensor of logits
    target
        A [num_points, num_classes] tensor of probabilistic target labels
    weight
        An optional [num_classes] tensor of per-class weights to multiply the loss by
    reduction
        One of "none", "mean", "sum", indicating whether to return one loss per data
        point, the mean loss, or the sum of losses

    Returns
    -------
    torch.Tensor
        The calculated loss

    Raises
    ------
    ValueError
        If an invalid reduction keyword is submitted
    """
    num_points, num_classes = input.shape
    # Note that t.new_zeros, t.new_full put tensor on same device as t
    cum_losses = input.new_zeros(num_points)
    for y in range(num_classes):
        # Compute the per-point loss as if every point had hard label y...
        target_temp = input.new_full((num_points,), y, dtype=torch.long)
        y_loss = F.cross_entropy(input, target_temp, reduction="none")
        if weight is not None:
            y_loss = y_loss * weight[y]
        # ...then weight that loss by the probability mass each point puts on class y
        cum_losses += target[:, y].float() * y_loss

    if reduction == "none":
        return cum_losses
    elif reduction == "mean":
        return cum_losses.mean()
    elif reduction == "sum":
        return cum_losses.sum()
    else:
        raise ValueError("Keyword 'reduction' must be one of ['none', 'mean', 'sum']")
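

# --- Example usage (illustrative sketch, not part of the original module) ---
# With one-hot "probabilities", the per-class loop above reduces to ordinary
# hard-label cross-entropy, so the result should match F.cross_entropy up to
# floating-point error. The shapes and values below are made-up demo inputs.
# Ignoring class weights, the loop is equivalent to the closed form:
#     loss_i = -sum_y target[i, y] * log_softmax(input, dim=1)[i, y]
if __name__ == "__main__":
    logits = torch.randn(4, 3)  # [num_points, num_classes] logits
    hard_labels = torch.tensor([0, 2, 1, 0])  # integer labels
    soft_labels = F.one_hot(hard_labels, num_classes=3).float()  # one-hot probs

    loss_soft = cross_entropy_with_probs(logits, soft_labels)
    loss_hard = F.cross_entropy(logits, hard_labels)
    assert torch.allclose(loss_soft, loss_hard)

    # Vectorized equivalent of the per-class loop, for comparison
    loss_vec = -(soft_labels * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
    assert torch.allclose(loss_soft, loss_vec)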
