mapk

MAPK

Bases: Metric

Mean Average Precision at K (MAPK) metric.

This class inherits from the Metric class of the torchmetrics library.
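For predicted labels P = (p_1, ..., p_m) and ground-truth set A, the apk method below computes the per-sample average precision at k as

    \mathrm{AP@k}(P, A) = \frac{1}{\min(|A|, k)} \sum_{i=1}^{\min(m, k)} \frac{\mathrm{hits}(i)}{i} \, \mathbf{1}\big[ p_i \in A \text{ and } p_i \notin \{p_1, \dots, p_{i-1}\} \big]

where hits(i) is the number of correct, previously unseen predictions among the first i. update averages AP@k over the samples of a batch, and compute returns the mean of these per-batch scores.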

Parameters:

    k (int): The number of top predictions to consider when calculating the metric. Default: 10.
    dist_sync_on_step (bool): Whether to synchronize the metric states across processes during the forward pass. Default: False.

Attributes:

    total (Tensor): The cumulative sum of the metric scores across all batches.
    count (Tensor): The number of batches processed.

Examples:

>>> from spotpython.torch.mapk import MAPK
>>> import torch
>>> mapk = MAPK(k=2)
>>> target = torch.tensor([0, 1, 2, 2])
>>> preds = torch.tensor(
...     [
...         [0.5, 0.2, 0.2],  # 0 is in top 2
...         [0.3, 0.4, 0.2],  # 1 is in top 2
...         [0.2, 0.4, 0.3],  # 2 is in top 2
...         [0.7, 0.2, 0.1],  # 2 isn't in top 2
...     ]
... )
>>> mapk.update(preds, target)
>>> print(mapk.compute()) # tensor(0.6250)
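Because total accumulates one mean score per batch and count the number of batches, compute() returns the average of the per-batch MAP@K values. A minimal two-batch sketch (illustrative values, not from the source):

    from spotpython.torch.mapk import MAPK
    import torch

    mapk = MAPK(k=2)
    # Batch 1: both targets ranked first -> batch score 1.0
    mapk.update(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1]))
    # Batch 2: class 0 only ranked second -> batch score (0.5 + 1.0) / 2 = 0.75
    mapk.update(torch.tensor([[0.1, 0.9], [0.2, 0.8]]), torch.tensor([0, 1]))
    print(mapk.compute())  # tensor(0.8750) = (1.0 + 0.75) / 2
    mapk.reset()  # clears `total` and `count` for the next evaluation run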
Source code in spotpython/torch/mapk.py
class MAPK(torchmetrics.Metric):
    """
    Mean Average Precision at K (MAPK) metric.

    This class inherits from the `Metric` class of the `torchmetrics` library.

    Args:
        k (int):
            The number of top predictions to consider when calculating the metric.
        dist_sync_on_step (bool):
            Whether to synchronize the metric states across processes during the forward pass.

    Attributes:
        total (torch.Tensor):
            The cumulative sum of the metric scores across all batches.
        count (torch.Tensor):
            The number of batches processed.

    Examples:
        >>> from spotpython.torch.mapk import MAPK
        >>> import torch
        >>> mapk = MAPK(k=2)
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> preds = torch.tensor(
        ...     [
        ...         [0.5, 0.2, 0.2],  # 0 is in top 2
        ...         [0.3, 0.4, 0.2],  # 1 is in top 2
        ...         [0.2, 0.4, 0.3],  # 2 is in top 2
        ...         [0.7, 0.2, 0.1],  # 2 isn't in top 2
        ...     ]
        ... )
        >>> mapk.update(preds, target)
        >>> print(mapk.compute()) # tensor(0.6250)
    """

    def __init__(self, k=10, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.k = k
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predicted: torch.Tensor, actual: torch.Tensor):
        """
        Update the state variables with a new batch of data.

        Args:
            predicted (torch.Tensor):
                A 2D tensor containing the predicted scores for each class.
            actual (torch.Tensor):
                A 1D tensor containing the ground truth labels.
        Returns:
            (NoneType): None

        Examples:
            >>> from spotpython.torch.mapk import MAPK
            >>> import torch
            >>> mapk = MAPK(k=2)
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> preds = torch.tensor(
            ...     [
            ...         [0.5, 0.2, 0.2],  # 0 is in top 2
            ...         [0.3, 0.4, 0.2],  # 1 is in top 2
            ...         [0.2, 0.4, 0.3],  # 2 is in top 2
            ...         [0.7, 0.2, 0.1],  # 2 isn't in top 2
            ...     ]
            ... )
            >>> mapk.update(preds, target)
            >>> print(mapk.compute()) # tensor(0.6250)

        Raises:
            AssertionError: If the actual tensor is not 1D or the predicted tensor is not 2D.
            AssertionError: If the actual and predicted tensors do not contain the same number of elements.

        """
        assert len(actual.shape) == 1, "actual must be a 1D tensor"
        assert len(predicted.shape) == 2, "predicted must be a 2D tensor"
        assert actual.shape[0] == predicted.shape[0], "actual and predicted must have the same number of elements"

        # Convert actual to list of lists
        actual = actual.tolist()
        actual = [[a] for a in actual]

        # Convert predicted to list of lists of indices sorted by confidence score
        _, predicted = predicted.topk(k=self.k, dim=1)
        predicted = predicted.tolist()
        # Use out-of-place additions below to avoid the PyTorch error:
        # "Inplace update to inference tensor outside InferenceMode is not allowed."
        score = np.mean([self.apk(p, a, self.k) for p, a in zip(predicted, actual)])
        self.total = self.total + score
        self.count = self.count + 1

    def compute(self) -> float:
        """
        Compute the mean average precision at k.

        Args:
            self (MAPK):
                The current instance of the class.

        Returns:
            (float):
                The mean average precision at k.

        Examples:
            >>> mapk = MAPK()
            >>> mapk.total = 3.0
            >>> mapk.count = 2
            >>> mapk.compute()
            1.5
        """
        return self.total / self.count

    @staticmethod
    def apk(predicted: List[int], actual: List[int], k: int = 10) -> float:
        """
        Calculate the average precision at k for a single pair of actual and predicted labels.

        Args:
            predicted (list): A list of predicted labels.
            actual (list): A list of ground truth labels.
            k (int): The number of top predictions to consider.

        Returns:
            float: The average precision at k.

        Examples:
            >>> MAPK.apk([1, 3, 2, 4], [1, 2, 3], 3)
            1.0
        """
        if not actual:
            return 0.0

        if len(predicted) > k:
            predicted = predicted[:k]

        score = 0.0
        num_hits = 0.0

        for i, p in enumerate(predicted):
            if p in actual and p not in predicted[:i]:
                num_hits += 1.0
                score += num_hits / (i + 1.0)

        return score / min(len(actual), k)

apk(predicted, actual, k=10) staticmethod

Calculate the average precision at k for a single pair of actual and predicted labels.

Parameters:

    predicted (list): A list of predicted labels. Required.
    actual (list): A list of ground truth labels. Required.
    k (int): The number of top predictions to consider. Default: 10.

Returns:

    float: The average precision at k.

Examples:

>>> MAPK.apk([1, 3, 2, 4], [1, 2, 3], 3)
1.0
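To see where the 1.0 comes from: predicted is truncated to the top k=3 entries [1, 3, 2], all three are hits, so the precision terms are 1/1, 2/2, and 3/3, and the result is 3.0 / min(3, 3) = 1.0. A call with a miss (illustrative values, not from the source) yields a non-trivial score:

    from spotpython.torch.mapk import MAPK

    # Hit at rank 1 (precision 1/1), miss at rank 2, hit at rank 3 (precision 2/3):
    print(MAPK.apk([1, 4, 2], [1, 2, 3], 3))  # (1/1 + 2/3) / min(3, 3) ≈ 0.5556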
Source code in spotpython/torch/mapk.py
@staticmethod
def apk(predicted: List[int], actual: List[int], k: int = 10) -> float:
    """
    Calculate the average precision at k for a single pair of actual and predicted labels.

    Args:
        predicted (list): A list of predicted labels.
        actual (list): A list of ground truth labels.
        k (int): The number of top predictions to consider.

    Returns:
        float: The average precision at k.

    Examples:
        >>> MAPK.apk([1, 3, 2, 4], [1, 2, 3], 3)
        1.0
    """
    if not actual:
        return 0.0

    if len(predicted) > k:
        predicted = predicted[:k]

    score = 0.0
    num_hits = 0.0

    for i, p in enumerate(predicted):
        if p in actual and p not in predicted[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)

    return score / min(len(actual), k)

compute()

Compute the mean average precision at k.

Parameters:

    self (MAPK): The current instance of the class.

Returns:

    float: The mean average precision at k.

Examples:

>>> mapk = MAPK()
>>> mapk.total = 3.0
>>> mapk.count = 2
>>> mapk.compute()
1.5
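Note that in normal use total is a tensor state (it defaults to torch.tensor(0.0)), so compute() returns a zero-dimensional tensor rather than a plain float; a brief usage sketch:

    value = mapk.compute()  # e.g. tensor(0.6250) after the update example above
    value.item()            # unwraps it to a plain Python float, e.g. 0.625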
Source code in spotpython/torch/mapk.py
def compute(self) -> float:
    """
    Compute the mean average precision at k.

    Args:
        self (MAPK):
            The current instance of the class.

    Returns:
        (float):
            The mean average precision at k.

    Examples:
        >>> mapk = MAPK()
        >>> mapk.total = 3.0
        >>> mapk.count = 2
        >>> mapk.compute()
        1.5
    """
    return self.total / self.count

update(predicted, actual)

Update the state variables with a new batch of data.

Parameters:

    predicted (Tensor): A 2D tensor containing the predicted scores for each class.
    actual (Tensor): A 1D tensor containing the ground truth labels.

Returns:

    None

Examples:

>>> from spotpython.torch.mapk import MAPK
>>> import torch
>>> mapk = MAPK(k=2)
>>> target = torch.tensor([0, 1, 2, 2])
>>> preds = torch.tensor(
...     [
...         [0.5, 0.2, 0.2],  # 0 is in top 2
...         [0.3, 0.4, 0.2],  # 1 is in top 2
...         [0.2, 0.4, 0.3],  # 2 is in top 2
...         [0.7, 0.2, 0.1],  # 2 isn't in top 2
...     ]
... )
>>> mapk.update(preds, target)
>>> print(mapk.compute()) # tensor(0.6250)

Raises:

    AssertionError: If the actual tensor is not 1D or the predicted tensor is not 2D.
    AssertionError: If the actual and predicted tensors do not contain the same number of elements.
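Internally, update reduces each row of prediction scores to its top-k class indices before scoring; a minimal sketch of that step (illustrative values):

    import torch

    preds = torch.tensor([[0.2, 0.4, 0.3]])
    _, topk_idx = preds.topk(k=2, dim=1)  # indices of the 2 largest scores, best first
    print(topk_idx.tolist())  # [[1, 2]] -- class 1 ranked first, class 2 second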

Source code in spotpython/torch/mapk.py
def update(self, predicted: torch.Tensor, actual: torch.Tensor):
    """
    Update the state variables with a new batch of data.

    Args:
        predicted (torch.Tensor):
            A 2D tensor containing the predicted scores for each class.
        actual (torch.Tensor):
            A 1D tensor containing the ground truth labels.
    Returns:
        (NoneType): None

    Examples:
        >>> from spotpython.torch.mapk import MAPK
        >>> import torch
        >>> mapk = MAPK(k=2)
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> preds = torch.tensor(
        ...     [
        ...         [0.5, 0.2, 0.2],  # 0 is in top 2
        ...         [0.3, 0.4, 0.2],  # 1 is in top 2
        ...         [0.2, 0.4, 0.3],  # 2 is in top 2
        ...         [0.7, 0.2, 0.1],  # 2 isn't in top 2
        ...     ]
        ... )
        >>> mapk.update(preds, target)
        >>> print(mapk.compute()) # tensor(0.6250)

    Raises:
        AssertionError: If the actual tensor is not 1D or the predicted tensor is not 2D.
        AssertionError: If the actual and predicted tensors do not contain the same number of elements.

    """
    assert len(actual.shape) == 1, "actual must be a 1D tensor"
    assert len(predicted.shape) == 2, "predicted must be a 2D tensor"
    assert actual.shape[0] == predicted.shape[0], "actual and predicted must have the same number of elements"

    # Convert actual to list of lists
    actual = actual.tolist()
    actual = [[a] for a in actual]

    # Convert predicted to list of lists of indices sorted by confidence score
    _, predicted = predicted.topk(k=self.k, dim=1)
    predicted = predicted.tolist()
    # Use out-of-place additions below to avoid the PyTorch error:
    # "Inplace update to inference tensor outside InferenceMode is not allowed."
    score = np.mean([self.apk(p, a, self.k) for p, a in zip(predicted, actual)])
    self.total = self.total + score
    self.count = self.count + 1