Skip to content

base

Config

Bases: ABC

Base class for all configurations.

All configurations inherit from this class, be they stored in a file or generated on the fly.

Source code in spotriver/data/base.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class Config(abc.ABC):
    """Base class for all configurations.

    All configurations inherit from this class, be they stored in a file or generated on the fly.
    """

    def __init__(self) -> None:
        # Nothing to initialize at the base level; subclasses add their own state.
        pass

    @property
    def desc(self):
        """Return the description from the docstring."""
        # Everything before the first underlined section header
        # (e.g. "Parameters\n    ----------") is the description.
        head, *_ = re.split(r"\w+\n\s{4}\-{3,}", self.__doc__)
        return inspect.cleandoc(head)

    @property
    def _repr_content(self):
        """The items that are displayed in the __repr__ method.

        This property can be overridden in order to modify the output of the __repr__ method.
        """
        return {"Name": self.__class__.__name__}

desc property

Return the description from the docstring.

Dataset

Bases: ABC

Base class for all datasets.

All datasets inherit from this class, be they stored in a file or generated on the fly.

Note

The code is based on code from the river package [1] to provide a similar interface.

Parameters:

Name Type Description Default
task str

Type of task the dataset is meant for. Should be one of: - “Regression” - “Binary classification” - “Multi-class classification” - “Multi-output binary classification” - “Multi-output regression”

required
n_features int

Number of features in the dataset.

None
n_samples int

Number of samples in the dataset.

None
n_classes int

Number of classes in the dataset, only applies to classification datasets.

None
n_outputs int

Number of outputs the target is made of, only applies to multi-output datasets.

None
sparse bool

Whether the dataset is sparse or not.

False
References

[1]: Base class for all datasets in River.

Source code in spotriver/data/base.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
class Dataset(abc.ABC):
    """Base class for all datasets.

    All datasets inherit from this class, be they stored in a file or generated on the fly.

    Note:
        The code is based on code from the river package [1] to provide a similar interface.

    Args:
        task (str):
            Type of task the dataset is meant for. Should be one of:
            - "Regression"
            - "Binary classification"
            - "Multi-class classification"
            - "Multi-output binary classification"
            - "Multi-output regression"
        n_features (int):
            Number of features in the dataset.
        n_samples (int):
            Number of samples in the dataset.
        n_classes (int):
            Number of classes in the dataset, only applies to classification datasets.
        n_outputs (int):
            Number of outputs the target is made of, only applies to multi-output datasets.
        sparse (bool):
            Whether the dataset is sparse or not.

    References:
        [1]: [Base class for all datasets in River.](https://riverml.xyz/0.18.0/api/datasets/base/Dataset/)

    """

    def __init__(
        self,
        task: str,
        n_features: int = None,
        n_samples=None,
        n_classes=None,
        n_outputs=None,
        sparse=False,
    ):
        # Plain attribute storage; the counters may legitimately stay None
        # (e.g. unbounded synthetic generators).
        self.task = task
        self.sparse = sparse
        self.n_features = n_features
        self.n_samples = n_samples
        self.n_classes = n_classes
        self.n_outputs = n_outputs

    @abc.abstractmethod
    def __iter__(self):
        # Subclasses must yield the dataset's samples.
        raise NotImplementedError

    def take(self, k: int):
        """Iterate over the k samples."""
        return itertools.islice(self, k)

    @property
    def desc(self):
        """Return the description from the docstring."""
        # Keep only the text before the first underlined section header.
        head = re.split(r"\w+\n\s{4}\-{3,}", self.__doc__)[0]
        return inspect.cleandoc(head)

    @property
    def _repr_content(self):
        """The items that are displayed in the __repr__ method.

        This property can be overridden in order to modify the output of the __repr__ method.
        """
        content = {
            "Name": self.__class__.__name__,
            "Task": self.task,
        }
        # Synthetic generators without a sample count are effectively infinite.
        if isinstance(self, SyntheticDataset) and self.n_samples is None:
            content["Samples"] = "∞"
        elif self.n_samples:
            content["Samples"] = f"{self.n_samples:,}"
        for label, value in (
            ("Features", self.n_features),
            ("Outputs", self.n_outputs),
            ("Classes", self.n_classes),
        ):
            if value:
                content[label] = f"{value:,}"
        content["Sparse"] = str(self.sparse)

        return content

    def __repr__(self):
        items = self._repr_content
        l_len = max(map(len, items.keys()))
        r_len = max(map(len, items.values()))

        rows = [k.rjust(l_len) + "  " + v.ljust(r_len) for k, v in items.items()]
        out = f"{self.desc}\n\n" + "\n".join(rows)

        # If the (NumPy-style) docstring has a Parameters section, append it verbatim.
        marker = "Parameters\n    ----------"
        if marker in self.__doc__:
            tail = re.split(marker, self.__doc__)[1]
            params = re.split(r"\w+\n\s{4}\-{3,}", tail)[0].rstrip()
            out += f"\n\nParameters\n----------{params}"

        return out

desc property

Return the description from the docstring.

take(k)

Iterate over the k samples.

Source code in spotriver/data/base.py
130
131
132
def take(self, k: int):
    """Iterate over the k samples.

    Args:
        k (int): Number of samples to take from the start of the stream.

    Returns:
        An iterator over at most the first k items of this dataset.
    """
    return itertools.islice(self, k)

FileConfig

Bases: Config

Base class for configurations that are stored in a local file.

Parameters:

Name Type Description Default
filename str

The file’s name.

required
directory str

The directory where the file is contained. Defaults to the location of the datasets module.

None
Source code in spotriver/data/base.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
class FileConfig(Config):
    """Base class for configurations that are stored in a local file.

    Args:
        filename (str):
            The file's name.
        directory (str):
            The directory where the file is contained. Defaults to the location of the `datasets` module.
    """

    def __init__(self, filename, directory=None, **desc):
        super().__init__(**desc)
        self.filename = filename
        self.directory = directory

    @property
    def path(self):
        # Anchor at the explicit directory when given, otherwise next to this module.
        base = pathlib.Path(self.directory) if self.directory else pathlib.Path(__file__).parent
        return base.joinpath(self.filename)

    @property
    def _repr_content(self):
        # Extend the inherited repr items with the resolved file path.
        return {**super()._repr_content, "Path": str(self.path)}

FileDataset

Bases: Dataset

Base class for datasets that are stored in a local file.

Small datasets that are part of the spotriver package inherit from this class.

Parameters:

Name Type Description Default
filename str

The file’s name.

required
directory str

The directory where the file is contained. Defaults to the location of the datasets module.

None
Source code in spotriver/data/base.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
class FileDataset(Dataset):
    """Base class for datasets that are stored in a local file.

    Small datasets that are part of the spotriver package inherit from this class.

    Args:
        filename (str): The file's name.
        directory (str):
            The directory where the file is contained. Defaults to the location of the `datasets` module.

    """

    def __init__(self, filename, directory=None, **desc):
        super().__init__(**desc)
        self.filename = filename
        self.directory = directory

    @property
    def path(self):
        # The explicit directory wins; fall back to this module's folder.
        anchor = pathlib.Path(self.directory) if self.directory else pathlib.Path(__file__).parent
        return anchor / self.filename

    @property
    def _repr_content(self):
        # Inherited repr items, plus where the data file lives.
        content = dict(super()._repr_content)
        content["Path"] = str(self.path)
        return content

GenericFileDataset

Bases: Dataset

Base class for datasets that are stored in a local file.

Small datasets that are part of the spotriver package inherit from this class.

Parameters:

Name Type Description Default
filename str

The file’s name.

required
directory str

The directory where the file is contained. Defaults to the location of the datasets module.

None
Source code in spotriver/data/base.py
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
class GenericFileDataset(Dataset):
    """Base class for datasets that are stored in a local file.

    Small datasets that are part of the spotriver package inherit from this class.

    Args:
        filename (str): The file's name.
        directory (str):
            The directory where the file is contained. Defaults to the location of the `datasets` module.
    """

    def __init__(self, filename, target, converters, parse_dates, directory=None, **desc):
        super().__init__(**desc)
        # Parsing configuration for the backing file (target column,
        # per-column converters, date columns).
        self.filename = filename
        self.target = target
        self.converters = converters
        self.parse_dates = parse_dates
        self.directory = directory

    @property
    def path(self):
        # Resolve relative to the explicit directory, or this module's folder.
        root = pathlib.Path(self.directory) if self.directory else pathlib.Path(__file__).parent
        return root / self.filename

    @property
    def _repr_content(self):
        # Add the resolved file path to the inherited repr items.
        content = super()._repr_content
        content.update(Path=str(self.path))
        return content

RemoteDataset

Bases: FileDataset

Base class for datasets that are stored in a remote file.

Medium and large datasets that are not part of the river package inherit from this class.

Note

The filename doesn’t have to be provided if unpack is False. Indeed in the latter case the filename will be inferred from the URL.

Parameters:

Name Type Description Default
url str

The URL the dataset is located at.

required
size int

The expected download size.

required
unpack bool

Whether to unpack the download or not.

True
filename str

An optional name to give to the file if the file is unpacked.

None
Source code in spotriver/data/base.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
class RemoteDataset(FileDataset):
    """Base class for datasets that are stored in a remote file.

    Medium and large datasets that are not part of the river package inherit from this class.

    Note:
        The filename doesn't have to be provided if unpack is False. Indeed in the latter case the
        filename will be inferred from the URL.

    Args:
        url (str):
            The URL the dataset is located at.
        size (int):
            The expected download size.
        unpack (bool):
            Whether to unpack the download or not.
        filename (str):
            An optional name to give to the file if the file is unpacked.
    """

    def __init__(self, url, size, unpack=True, filename=None, **desc):
        # Default the local filename to the last path component of the URL.
        if filename is None:
            filename = path.basename(url)

        super().__init__(filename=filename, **desc)
        self.url = url
        self.size = size
        self.unpack = unpack

    @property
    def path(self):
        # Downloads live under the data home, namespaced by the subclass name.
        return pathlib.Path(get_data_home(), self.__class__.__name__, self.filename)

    def download(self, force=False, verbose=True):
        """Fetch the remote archive into the local data home and optionally unpack it.

        Args:
            force (bool): Re-download even if the data already appears downloaded.
            verbose (bool): Print progress information to stdout.

        Raises:
            RuntimeError: If the archive extension is not zip, gz or tar.
        """
        if not force and self.is_downloaded:
            return

        # Determine where to download the archive
        directory = self.path.parent
        directory.mkdir(parents=True, exist_ok=True)
        archive_path = directory.joinpath(path.basename(self.url))

        with request.urlopen(self.url) as r:
            # Notify the user
            if verbose:
                meta = r.info()
                try:
                    # Content-Length may be absent (chunked responses), hence the KeyError guard.
                    n_bytes = int(meta["Content-Length"])
                    msg = f"Downloading {self.url} ({utils.pretty.humanize_bytes(n_bytes)})"
                except KeyError:
                    msg = f"Downloading {self.url}"
                print(msg)

            # Now dump the contents of the requests
            with open(archive_path, "wb") as f:
                shutil.copyfileobj(r, f)

        if not self.unpack:
            return

        if verbose:
            print(f"Uncompressing into {directory}")

        # NOTE(review): extractall on an untrusted archive is vulnerable to path
        # traversal ("zip/tar slip") — consider validating member paths before extracting.
        if archive_path.suffix.endswith("zip"):
            with zipfile.ZipFile(archive_path, "r") as zf:
                zf.extractall(directory)

        elif archive_path.suffix.endswith(("gz", "tar")):
            # Plain .tar is opened uncompressed; .gz (e.g. .tar.gz) as gzip.
            mode = "r:" if archive_path.suffix.endswith("tar") else "r:gz"
            tar = tarfile.open(archive_path, mode)
            tar.extractall(directory)
            tar.close()

        else:
            raise RuntimeError(f"Unhandled extension type: {archive_path.suffix}")

        # Delete the archive file now that it has been uncompressed
        archive_path.unlink()

    @abc.abstractmethod
    def _iter(self):
        # Subclasses yield the parsed samples from the downloaded data.
        pass

    @property
    def is_downloaded(self):
        """Indicate whether or not the data has been correctly downloaded."""
        if self.path.exists():
            if self.path.is_file():
                # A single file is verified against the expected byte size.
                return self.path.stat().st_size == self.size
            # For an unpacked directory this returns the total byte count
            # (truthy when non-empty) rather than an exact size check —
            # presumably intentional best-effort; TODO confirm.
            return sum(f.stat().st_size for f in self.path.glob("**/*") if f.is_file())

        return False

    def __iter__(self):
        # Lazily download on first iteration, then verify it actually worked.
        if not self.is_downloaded:
            self.download(verbose=True)
        if not self.is_downloaded:
            raise RuntimeError("Something went wrong during the download")
        yield from self._iter()

    @property
    def _repr_content(self):
        # Extend the inherited repr items with remote-specific information.
        content = super()._repr_content
        content["URL"] = self.url
        content["Size"] = utils.pretty.humanize_bytes(self.size)
        content["Downloaded"] = str(self.is_downloaded)
        return content

is_downloaded property

Indicate whether or not the data has been correctly downloaded.

SyntheticDataset

Bases: Dataset

A synthetic dataset.

All synthetic datasets inherit from this class.

Source code in spotriver/data/base.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
class SyntheticDataset(Dataset):
    """A synthetic dataset.

    All synthetic datasets inherit from this class.

    """

    def __repr__(self):
        # Two aligned tables: the inherited properties, then the
        # constructor parameters the generator was configured with.
        props = self._repr_content
        params = self._get_params()

        prop_key_w = max(map(len, props.keys()))
        prop_val_w = max(map(len, props.values()))
        param_key_w = max(map(len, params.keys()))
        param_val_w = max(map(len, map(str, params.values())))

        prop_rows = [k.rjust(prop_key_w) + "  " + v.ljust(prop_val_w) for k, v in props.items()]
        param_rows = [k.rjust(param_key_w) + "  " + str(v).ljust(param_val_w) for k, v in params.items()]

        return (
            "Synthetic data generator\n\n"
            + "\n".join(prop_rows)
            + "\n\nConfiguration\n-------------\n"
            + "\n".join(param_rows)
        )

    def _get_params(self) -> typing.Dict[str, typing.Any]:
        """Return the parameters that were used during initialization."""
        params = {}
        for name, param in inspect.signature(self.__init__).parameters.items():  # type: ignore
            if param.kind == param.VAR_KEYWORD:
                continue
            params[name] = getattr(self, name)
        return params

get_data_home(data_home=None)

Return the location where remote datasets are to be stored. By default the data directory is set to a folder named ‘spotriver_data’ in the user home folder. Alternatively, it can be set by the ‘SPOTRIVER_DATA’ environment variable or programmatically by giving an explicit folder path. The ‘~’ symbol is expanded to the user home folder. If the folder does not already exist, it is automatically created.

Parameters:

Name Type Description Default
data_home str

The path to spotriver data directory. If None, the default path is ~/spotriver_data.

None

Returns:

Name Type Description
data_home str

The path to the spotriver data directory.

Source code in spotriver/data/base.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def get_data_home(data_home=None) -> Path:
    """Return the location where remote datasets are to be stored.

    By default the data directory is a folder named 'spotriver_data' in the
    user home folder. Alternatively, it can be set by the 'SPOTRIVER_DATA'
    environment variable or programmatically by giving an explicit folder
    path. The '~' symbol is expanded to the user home folder. If the folder
    does not already exist, it is automatically created.

    Args:
        data_home (str | Path | None):
            The path to the spotriver data directory. If `None`, the
            'SPOTRIVER_DATA' environment variable is consulted, falling back
            to `~/spotriver_data`.

    Returns:
        Path: The absolute path to the spotriver data directory.
    """
    if data_home is None:
        data_home = environ.get("SPOTRIVER_DATA", Path.home() / "spotriver_data")
    # Normalize to an absolute Path object. expanduser() implements the
    # documented '~' expansion (the previous code promised it but never did it);
    # the annotation is Path because a Path has always been what is returned.
    data_home = Path(data_home).expanduser().absolute()
    # Create the data directory if it does not exist yet.
    data_home.mkdir(parents=True, exist_ok=True)
    return data_home