Skip to content

torch_hyper_dict

TorchHyperDict

Bases: FileConfig

PyTorch hyperparameter dictionary.

This class extends the FileConfig class to provide a dictionary for storing hyperparameters.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `filename` | `str` | The name of the file where the hyperparameters are stored. |

Source code in `spotPython/hyperdict/torch_hyper_dict.py` (lines 6–44):
class TorchHyperDict(base.FileConfig):
    """PyTorch hyperparameter dictionary.

    This class extends the FileConfig class to provide a dictionary for storing hyperparameters.
    The hyperparameters are read from a JSON file when the instance is created.

    Attributes:
        filename (str): The name of the file where the hyperparameters are stored.
        directory (str | pathlib.Path | None): Directory containing ``filename``. When it is
            falsy (e.g. ``None``), the directory of this module is used instead.
        hyper_dict (dict): The hyperparameters loaded from the file by ``load()``.
    """

    def __init__(
        self,
        filename: str = "torch_hyper_dict.json",
        directory: "str | pathlib.Path | None" = None,
    ) -> None:
        """Initialize the dictionary and immediately load the hyperparameters from disk.

        Args:
            filename (str): Name of the JSON file. Defaults to "torch_hyper_dict.json".
            directory (str | pathlib.Path | None): Directory containing the file, or None to
                use the directory of this module.

        Raises:
            FileNotFoundError: If the resolved file does not exist (raised while loading).
            json.JSONDecodeError: If the file content is not valid JSON.
        """
        super().__init__(filename=filename, directory=directory)
        # NOTE(review): possibly redundant with FileConfig.__init__; kept so that `path`
        # does not depend on which attribute names the base class stores.
        self.filename = filename
        self.directory = directory
        self.hyper_dict = self.load()

    @property
    def path(self) -> pathlib.Path:
        """Return the full path of the hyperparameter file.

        Uses ``directory`` when it is truthy, otherwise the directory containing this module.
        """
        if self.directory:
            return pathlib.Path(self.directory).joinpath(self.filename)
        return pathlib.Path(__file__).parent.joinpath(self.filename)

    def load(self) -> dict:
        """Load the hyperparameters from the file.

        The file at ``self.path`` is decoded as UTF-8, the encoding required for JSON text,
        instead of the platform's locale-dependent default encoding.

        Returns:
            (dict): A dictionary containing the hyperparameters.

        Raises:
            FileNotFoundError: If ``self.path`` does not exist.
            json.JSONDecodeError: If the file content is not valid JSON.

        Examples:
            >>> thd = TorchHyperDict()
            >>> hyperparams = thd.load()
            >>> print(hyperparams)  # doctest: +SKIP  (output depends on the file contents)
            {'learning_rate': 0.001, 'batch_size': 32, 'epochs': 10}
        """
        with open(self.path, "r", encoding="utf-8") as f:
            return json.load(f)

load()

Load the hyperparameters from the file.

Returns:

| Type | Description |
|------|-------------|
| `dict` | A dictionary containing the hyperparameters. |

Examples (the printed output depends on the contents of the JSON file):

```python
>>> thd = TorchHyperDict()
>>> hyperparams = thd.load()
>>> print(hyperparams)
{'learning_rate': 0.001, 'batch_size': 32, 'epochs': 10}
```

Source code in `spotPython/hyperdict/torch_hyper_dict.py` (lines 31–44):
def load(self) -> dict:
    """Load the hyperparameters from the file.

    The file at ``self.path`` is decoded as UTF-8, the encoding required for JSON text,
    instead of the platform's locale-dependent default encoding.

    Returns:
        (dict): A dictionary containing the hyperparameters.

    Raises:
        FileNotFoundError: If ``self.path`` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.

    Examples:
        >>> thd = TorchHyperDict()
        >>> hyperparams = thd.load()
        >>> print(hyperparams)  # doctest: +SKIP  (output depends on the file contents)
        {'learning_rate': 0.001, 'batch_size': 32, 'epochs': 10}
    """
    with open(self.path, "r", encoding="utf-8") as f:
        return json.load(f)