litmodel
LitModel
Bases: LightningModule
A LightningModule class for a simple neural network model.
Attributes:
Name | Type | Description |
---|---|---|
l1 | int | The number of neurons in the first hidden layer. |
epochs | int | The number of epochs to train the model for. |
batch_size | int | The batch size to use during training. |
act_fn | str | The name of the activation function used in the hidden layers (e.g. 'relu'). |
optimizer | str | The name of the optimizer used during training (e.g. 'adam'). |
learning_rate | float | The learning rate for the optimizer. |
_L_in | int | The number of input features (28 * 28 = 784 for flattened MNIST images). |
_L_out | int | The number of output classes (10 for the MNIST digits). |
model | Sequential | The neural network model. |
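The attributes describe a small fully connected classifier for flattened 28 * 28 images. As a minimal sketch of how such a `Sequential` could be assembled, the hypothetical helper `build_mlp` below uses a single hidden layer of `l1` neurons and a hypothetical `ACTIVATIONS` lookup for `act_fn`. LitModel's real layer stack may differ (for example, it may contain more than one hidden layer).

```python
import torch
from torch import nn

# Hypothetical mapping from the act_fn string to a torch activation module.
ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid}


def build_mlp(l1: int, act_fn: str, _L_in: int = 28 * 28, _L_out: int = 10) -> nn.Sequential:
    """Hypothetical helper: flatten -> one hidden layer of l1 units -> log probabilities."""
    return nn.Sequential(
        nn.Flatten(),              # (B, 1, 28, 28) -> (B, 784)
        nn.Linear(_L_in, l1),      # first hidden layer with l1 neurons
        ACTIVATIONS[act_fn](),     # activation chosen by name
        nn.Linear(l1, _L_out),     # one score per class
        nn.LogSoftmax(dim=1),      # log probabilities, as forward() is documented to return
    )


model = build_mlp(l1=128, act_fn="relu")
x = torch.randn(4, 1, 28, 28)      # a fake batch of four MNIST-sized images
out = model(x)
print(out.shape)                   # torch.Size([4, 10])
```

Ending with `LogSoftmax` keeps the output consistent with the documented return value of `forward()`; an alternative design returns raw scores and applies the softmax inside the loss.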
Examples:
>>> import lightning as L
>>> from torch.utils.data import DataLoader
>>> from torchvision.datasets import MNIST
>>> from torchvision.transforms import ToTensor
>>> from spotpython.light.litmodel import LitModel
>>> PATH_DATASETS = "."  # any writable directory; MNIST is downloaded here
>>> BATCH_SIZE = 64  # example value
>>> train_data = MNIST(PATH_DATASETS, train=True, download=True, transform=ToTensor())
>>> train_loader = DataLoader(train_data, batch_size=BATCH_SIZE)
>>> lit_model = LitModel(l1=128, epochs=10, batch_size=BATCH_SIZE, act_fn='relu', optimizer='adam')
>>> trainer = L.Trainer(max_epochs=10)
>>> trainer.fit(lit_model, train_loader)
Source code in spotpython/light/litmodel.py
__init__(l1, epochs, batch_size, act_fn, optimizer, learning_rate=0.0002, _L_in=28 * 28, _L_out=10, *args, **kwargs)
Initializes the LitModel object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
l1 | int | The number of neurons in the first hidden layer. | required |
epochs | int | The number of epochs to train the model for. | required |
batch_size | int | The batch size to use during training. | required |
act_fn | str | The name of the activation function used in the hidden layers (e.g. 'relu'). | required |
optimizer | str | The name of the optimizer used during training (e.g. 'adam'). | required |
learning_rate | float | The learning rate for the optimizer. | 0.0002 |
_L_in | int | The number of input features. | 28 * 28 |
_L_out | int | The number of output classes. | 10 |
Returns:
Type | Description |
---|---|
NoneType | None |
configure_optimizers()
Configures the optimizer for the model.
Returns:
Type | Description |
---|---|
Optimizer | The torch.optim.Optimizer to use during training. |
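The optimizer is selected by the `optimizer` string (for example `'adam'`) and configured with `learning_rate`. The sketch below shows one way such a name-based lookup could work; `OPTIMIZERS` and `make_optimizer` are hypothetical names, and the set of supported optimizers is an assumption, not spotpython's actual list.

```python
import torch
from torch import nn

# Hypothetical lookup from the optimizer string to a torch.optim class.
OPTIMIZERS = {
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
    "sgd": torch.optim.SGD,
    "rmsprop": torch.optim.RMSprop,
}


def make_optimizer(name: str, params, learning_rate: float = 2e-4) -> torch.optim.Optimizer:
    """Hypothetical helper: build the optimizer named by `name` with the given learning rate."""
    try:
        optimizer_cls = OPTIMIZERS[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown optimizer: {name!r}") from None
    return optimizer_cls(params, lr=learning_rate)


net = nn.Linear(28 * 28, 10)
opt = make_optimizer("adam", net.parameters(), learning_rate=2e-4)
print(type(opt).__name__, opt.param_groups[0]["lr"])  # Adam 0.0002
```

Inside a LightningModule, `configure_optimizers()` could then return `make_optimizer(self.optimizer, self.parameters(), self.learning_rate)`, assuming the constructor stores those arguments as the attributes listed above.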
forward(x)
Performs a forward pass through the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | A tensor containing a batch of input data. | required |
Returns:
Type | Description |
---|---|
Tensor | A tensor containing the log probabilities for each class. |
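`forward()` is documented to return log probabilities rather than raw class scores. As a framework-free illustration of what that means, the sketch below computes a numerically stable log-softmax for one sample's scores in plain Python. It is not LitModel's code, which presumably uses a torch operation such as `torch.nn.functional.log_softmax` instead.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: z_i - log(sum_j exp(z_j))."""
    m = max(logits)  # subtracting the max keeps exp() from overflowing
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


log_probs = log_softmax([2.0, 1.0, 0.1])
probs = [math.exp(lp) for lp in log_probs]
print(round(sum(probs), 6))  # 1.0
```

Exponentiating the log probabilities recovers a distribution that sums to one, and the max subtraction lets very large scores such as `[1000.0, 1000.0]` be handled without overflow.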
test_step(batch, batch_idx)
Performs a single test step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | tuple | A tuple containing a batch of input data and labels. | required |
batch_idx | int | The index of the current batch. | required |
Returns:
Type | Description |
---|---|
tuple | A tuple containing the loss and accuracy for this batch. |
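`test_step()` is documented to return the loss and accuracy for a batch. Assuming the loss is the negative log-likelihood of the log probabilities returned by `forward()` (the page does not name the loss, so this is an assumption), both values can be computed as in this plain-Python sketch; `nll_and_accuracy` is a hypothetical name.

```python
import math
from typing import Sequence, Tuple


def nll_and_accuracy(log_probs: Sequence[Sequence[float]], labels: Sequence[int]) -> Tuple[float, float]:
    """Mean negative log-likelihood and accuracy for one batch of log probabilities."""
    n = len(labels)
    # NLL takes the log probability of the true class for each sample and negates the mean.
    loss = -sum(row[y] for row, y in zip(log_probs, labels)) / n
    # The predicted class is the index with the highest log probability.
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    acc = sum(p == y for p, y in zip(preds, labels)) / n
    return loss, acc


batch = [[math.log(0.7), math.log(0.2), math.log(0.1)],
         [math.log(0.1), math.log(0.3), math.log(0.6)]]
loss, acc = nll_and_accuracy(batch, labels=[0, 1])
print(round(loss, 4), acc)  # 0.7803 0.5
```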
training_step(batch)
Performs a single training step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | tuple | A tuple containing a batch of input data and labels. | required |
Returns:
Type | Description |
---|---|
Tensor | A tensor containing the loss for this batch. |
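To show what surrounds `training_step()` during `trainer.fit()`, the sketch below writes out a simplified version of one optimization step in plain PyTorch. Lightning's automatic optimization performs the equivalent zero-grad, backward and step calls for you (its exact ordering differs slightly). The stand-in model and the use of `F.nll_loss` on log probabilities are assumptions, not LitModel's documented internals.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model that outputs log probabilities for 10 classes.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)


def training_step(batch):
    """Sketch of a training_step body: forward pass plus NLL loss on log probabilities."""
    x, y = batch
    log_probs = model(x)
    return F.nll_loss(log_probs, y)


# Simplified version of what Lightning does around training_step, written out by hand.
batch = (torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))
optimizer.zero_grad()
loss = training_step(batch)
loss.backward()
optimizer.step()
print(loss.item())
```

Because `training_step()` only returns the loss, the backward pass and optimizer update can stay in the framework, which is the main reason to write the model as a LightningModule.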
validation_step(batch, batch_idx)
Performs a single validation step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | tuple | A tuple containing a batch of input data and labels. | required |
batch_idx | int | The index of the current batch. | required |
Returns:
Type | Description |
---|---|
None | None |
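`validation_step()` returns `None`, which suggests it records the validation loss through logging (for example `self.log`) instead of returning it; that is an assumption. The sketch below, with a hypothetical stand-in model, shows the conditions Lightning applies during validation, namely evaluation mode and disabled gradient tracking, using `model.eval()` and `torch.no_grad()`. The function returns the loss only so the result can be inspected.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical stand-in model that outputs log probabilities for 10 classes.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))


def validation_step(batch):
    """Sketch: compute the validation loss for one batch."""
    x, y = batch
    log_probs = model(x)
    return F.nll_loss(log_probs, y)


model.eval()            # switch layers such as dropout to inference behavior
with torch.no_grad():   # no autograd graph is built during validation
    val_loss = validation_step((torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,))))
print(val_loss.requires_grad)  # False
```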