tsio.py 17.6 KB
Newer Older
1
from datetime import datetime
2
from contextlib import contextmanager
3
import logging
4
5
6

import pandas as pd

7
from sqlalchemy import Table, Column, Integer, ForeignKey
Aurélien Campéas's avatar
Aurélien Campéas committed
8
from sqlalchemy.sql.expression import select, func, desc
9
from sqlalchemy.dialects.postgresql import BYTEA
10

11
from tshistory.schema import tsschema
12
13
from tshistory.util import (
    inject_in_index,
14
    num2float,
15
    subset,
16
    SeriesServices,
17
18
    tzaware_serie
)
19
from tshistory.snapshot import Snapshot
20

21
L = logging.getLogger('tshistory.tsio')
22
TABLES = {}
23
24


25
class TimeSerie(SeriesServices):
26
    namespace = 'tsh'
27
    schema = None
28
29
30

    def __init__(self, namespace='tsh'):
        """Bind the instance to `namespace` and set up its schema.

        namespace: str handed to `tsschema` to build the schema object

        The schema's `define()` method is called right away and the
        per-instance metadata cache starts out empty.
        """
        self.namespace = namespace
        schema = tsschema(namespace)
        self.schema = schema
        schema.define()
        self.metadatacache = {}
34

35
36
37
    def insert(self, cn, newts, name, author,
               metadata=None,
               _insertion_date=None):
        """Create a new revision of a given time series

        newts: pandas.Series with date index
        name: str unique identifier of the serie
        author: str free-form author name
        metadata: optional dict for changeset metadata

        An empty serie is a no-op and yields None. Otherwise the work
        is delegated to `_create` when no table is registered for
        `name` yet, to `_update` when there is one, and their return
        value is passed through.
        """
        assert isinstance(newts, pd.Series)
        assert isinstance(name, str)
        assert isinstance(author, str)
        assert metadata is None or isinstance(metadata, dict)
        assert _insertion_date is None or isinstance(_insertion_date, datetime)
        assert not newts.index.duplicated().any()

        newts = num2float(newts)
        if len(newts) == 0:
            return None

        index = newts.index
        # `and` binds tighter than `or`: either the exact '<M8[ns]'
        # dtype, or a datetime-like dtype on a non-MultiIndex
        has_date_index = (
            '<M8[ns]' == index.dtype or
            'datetime' in str(index.dtype) and not
            isinstance(index, pd.MultiIndex)
        )
        assert has_date_index

        newts.name = name
        table = self._get_ts_table(cn, name)
        if table is not None:
            return self._update(cn, table, newts, name, author,
                                metadata, _insertion_date)

        return self._create(cn, newts, name, author,
                            metadata, _insertion_date)
70

71
    def get(self, cn, name, revision_date=None,
            from_value_date=None, to_value_date=None,
            _keep_nans=False):
        """Compute and return the serie of a given name

        revision_date: datetime filter to get previous versions of the
        serie
        from_value_date / to_value_date: forwarded as-is to
        `Snapshot.find`
        _keep_nans: when true, the result of `Snapshot.find` is
        returned untouched (null values kept, name not set)

        Returns None when no table is registered for `name`.
        """
        if self._get_ts_table(cn, name) is None:
            return None

        filters = []
        if revision_date:
            filters.append(
                lambda cset, _: cset.c.insertion_date <= revision_date
            )

        found = Snapshot(cn, self, name).find(
            filters,
            from_value_date=from_value_date,
            to_value_date=to_value_date
        )
        _, serie = found

        if serie is None or _keep_nans:
            return serie

        serie.name = name
        return serie[serie.notnull()]
96

97
    def metadata(self, cn, tsname):
98
99
100
101
102
103
104
105
106
107
108
        """Return metadata dict of timeserie."""
        if tsname in self.metadatacache:
            return self.metadatacache[tsname]
        reg = self.schema.registry
        sql = select([reg.c.metadata]).where(
            reg.c.name == tsname
        )
        meta = cn.execute(sql).scalar()
        self.metadatacache[tsname] = meta
        return meta

109
110
111
112
113
114
115
116
    def changeset_metadata(self, cn, csid):
        """Return the metadata stored on changeset `csid`.

        cn: connection used to run the query
        csid: changeset id

        Yields the `metadata` column of the matching changeset row,
        or None when the query produces no value.
        """
        cset = self.schema.changeset
        # go through the table definition so that `csid` travels as a
        # bound parameter instead of being formatted into the SQL text
        sql = select([cset.c.metadata]).where(cset.c.id == csid)
        return cn.execute(sql).scalar()

117
118
    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None,
                    from_value_date=None,
                    to_value_date=None,
                    deltabefore=None,
                    deltaafter=None,
                    diffmode=False):
        """Return the successive versions of a serie as one pandas serie

        from_insertion_date / to_insertion_date: optional bounds
        (inclusive) on the changeset insertion date
        from_value_date / to_value_date: optional bounds on the value
        dates, handed to `subset` (diff mode) or `Snapshot.find`
        deltabefore / deltaafter: optional offsets; when at least one
        is given, each version is restricted to the window
        [insertion date - deltabefore, insertion date + deltaafter]
        (a missing side defaults to the insertion date itself). They
        cannot be combined with `diffmode` nor with the value date
        bounds.
        diffmode: when true, collect the stored diff of each revision
        (the first snapshot for the initial one) instead of the full
        serie at that revision

        Each version goes through `inject_in_index` with its insertion
        date before all versions are concatenated.

        Returns None when the serie is unknown or when no revision
        matches the insertion date range.
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return None

        # explicit None checks everywhere: a zero offset is a
        # legitimate window bound, not an absent one
        windowed = deltabefore is not None or deltaafter is not None
        if windowed:
            assert diffmode is False
            assert from_value_date is None
            assert to_value_date is None

        cset = self.schema.changeset

        # both modes walk the same revisions, diff mode merely needs
        # one more column
        columns = [cset.c.id, cset.c.insertion_date]
        if diffmode:
            columns.append(table.c.diff)
        revsql = select(
            columns
        ).order_by(
            cset.c.id
        ).where(
            table.c.cset == cset.c.id
        )

        if from_insertion_date:
            revsql = revsql.where(cset.c.insertion_date >= from_insertion_date)
        if to_insertion_date:
            revsql = revsql.where(cset.c.insertion_date <= to_insertion_date)

        revs = cn.execute(revsql).fetchall()
        if not revs:
            # it's fine to ask for an insertion date range
            # where nothing did happen, but you get nothing
            return None

        snapshot = Snapshot(cn, self, name)
        series = []

        if diffmode:
            # compute diffs above the snapshot
            for _csid, revdate, diff in revs:
                if diff is None:  # we must fetch the initial snapshot
                    serie = subset(snapshot.first,
                                   from_value_date, to_value_date)
                else:
                    serie = subset(self._deserialize(diff, name),
                                   from_value_date, to_value_date)
                    serie = self._ensure_tz_consistency(cn, serie)
                inject_in_index(serie, revdate)
                series.append(serie)
        else:
            for csid, idate in revs:
                if windowed:
                    from_value_date = (
                        idate if deltabefore is None
                        else idate - deltabefore
                    )
                    to_value_date = (
                        idate if deltaafter is None
                        else idate + deltaafter
                    )
                _, serie = snapshot.find(
                    [lambda cset, _: cset.c.id == csid],
                    from_value_date=from_value_date,
                    to_value_date=to_value_date
                )
                inject_in_index(serie, idate)
                series.append(serie)

        history = pd.concat(series)
        history.name = name
        return history
207

208
209
210
    def exists(self, cn, name):
        """Tell whether `_get_ts_table` finds a table for serie `name`."""
        table = self._get_ts_table(cn, name)
        return table is not None

211
    def latest_insertion_date(self, cn, name):
        """Return the most recent insertion date of serie `name`.

        The value is the maximum `insertion_date` among the changesets
        referenced by the serie table. Returns None when the serie is
        unknown, like `get` and `get_history` do.
        """
        tstable = self._get_ts_table(cn, name)
        if tstable is None:
            # unknown serie: nothing to aggregate (and no table to
            # build the query from)
            return None
        cset = self.schema.changeset
        sql = select(
            [func.max(cset.c.insertion_date)]
        ).where(
            tstable.c.cset == cset.c.id
        )
        return cn.execute(sql).scalar()
217

218
219
220
221
    def changeset_at(self, cn, seriename, revdate, mode='strict'):
        """Return a changeset id of `seriename` selected by insertion date

        mode: 'strict' keeps changesets inserted exactly at `revdate`,
        'before' those inserted at or before it, 'after' those
        inserted at or after it

        The first value of the result is returned (None when nothing
        matches).
        """
        assert mode in ('strict', 'before', 'after')
        cset = self.schema.changeset
        table = self._table_definition_for(seriename)
        idate = cset.c.insertion_date

        if mode == 'strict':
            datefilter = idate == revdate
        elif mode == 'before':
            datefilter = idate <= revdate
        else:
            datefilter = idate >= revdate

        # NOTE(review): there is no ORDER BY, so when several
        # changesets match in 'before'/'after' mode the one picked by
        # `.scalar()` is left to the database -- confirm the intent
        sql = select(
            [table.c.cset]
        ).where(
            table.c.cset == cset.c.id
        ).where(
            datefilter
        )
        return cn.execute(sql).scalar()

    def strip(self, cn, seriename, csid):
        """Remove the revisions of `seriename` from changeset `csid` on

        For every changeset reported by `log` from `csid` for this
        serie: a 'tshistory.info' note is written into the changeset
        metadata and the changeset/serie link is deleted. Finally the
        rows of the serie table whose changeset id is >= `csid` are
        deleted.

        Fails on an assertion when `log` reports nothing.
        """
        logs = self.log(cn, fromrev=csid, names=(seriename,))
        assert logs

        cset = self.schema.changeset
        cset_serie = self.schema.changeset_series

        for entry in logs:
            rev = entry['rev']

            # put stripping info in the changeset metadata
            metadata = self.changeset_metadata(cn, rev) or {}
            metadata['tshistory.info'] = 'got stripped from {}'.format(csid)
            cn.execute(
                cset.update().where(
                    cset.c.id == rev
                ).values(
                    metadata=metadata
                )
            )

            # delete the changeset_series item
            cn.execute(
                cset_serie.delete().where(
                    cset_serie.c.cset == rev
                ).where(
                    cset_serie.c.serie == seriename
                )
            )

        # wipe the diffs
        table = self._table_definition_for(seriename)
        cn.execute(table.delete().where(table.c.cset >= csid))
258

259
    def info(self, cn):
        """Gather global statistics on the current tshistory repository

        Returns a dict with:
        'series count': number of rows in the registry
        'changeset count': highest changeset id (max(id), not a row
        count)
        'serie names': distinct registry names, sorted
        """
        reg = self.schema.registry
        cset = self.schema.changeset

        # built from the table definitions (not from namespace
        # formatted into raw SQL) so that schema naming and quoting
        # are handled the same way as in the other queries
        count_sql = select([func.count()]).select_from(reg)
        maxid_sql = select([func.max(cset.c.id)])
        names_sql = select([reg.c.name]).distinct().order_by(reg.c.name)

        stats = {'series count': cn.execute(count_sql).scalar()}
        stats['changeset count'] = cn.execute(maxid_sql).scalar()
        stats['serie names'] = [
            name for name, in cn.execute(names_sql).fetchall()
        ]
        return stats

270
    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            stripped=False,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.

        limit: maximum number of changesets (most recent ones first
        in the query), 0 for no limit
        diff: when true, each entry gets a 'diff' dict mapping serie
        name to the output of `diff_at` for that changeset
        names: optional collection of serie names to keep
        authors: optional collection of authors to keep
        stripped: when true, changesets are outer-joined with their
        series so that changesets without any serie left also show up
        fromrev / torev: optional inclusive bounds on changeset ids
        fromdate / todate: optional inclusive bounds on insertion
        dates

        Returns a list of dicts with keys 'rev', 'author', 'date'
        (utc pandas Timestamp), 'meta', 'names' (and 'diff' when
        asked for), sorted by ascending 'rev'.
        """
        cset, cset_series, reg = (
            self.schema.changeset,
            self.schema.changeset_series,
            self.schema.registry
        )

        sql = select(
            [cset.c.id, cset.c.author, cset.c.insertion_date, cset.c.metadata]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)
        if names:
            sql = sql.where(reg.c.name.in_(names))
        if authors:
            sql = sql.where(cset.c.author.in_(authors))
        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)
        if torev:
            sql = sql.where(cset.c.id <= torev)
        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)
        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        if stripped:
            # outerjoin to show dead things
            # NOTE(review): nothing ties the registry to the join
            # here, so combining `stripped` with `names` looks
            # unconstrained -- confirm
            sql = sql.select_from(cset.outerjoin(cset_series))
        else:
            sql = sql.where(
                cset.c.id == cset_series.c.cset
            ).where(
                cset_series.c.serie == reg.c.name
            )

        log = []
        for csetid, author, revdate, meta in cn.execute(sql).fetchall():
            # localize naive dates, convert aware ones: handing an
            # aware datetime together with `tz` to the Timestamp
            # constructor is refused by recent pandas
            stamp = pd.Timestamp(revdate)
            if stamp.tzinfo is None:
                stamp = stamp.tz_localize('utc')
            else:
                stamp = stamp.tz_convert('utc')
            log.append({'rev': csetid, 'author': author,
                        'date': stamp,
                        'meta': meta or {},
                        'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in log:
                rev['diff'] = {name: self.diff_at(cn, rev['rev'], name)
                               for name in rev['names']}

        log.sort(key=lambda rev: rev['rev'])
        return log

323
324
    # /API
    # Helpers
325

326
327
    # creation / update

328
329
    def _create(self, cn, newts, name, author,
                metadata=None, insertion_date=None):
        """Handle the initial insertion of serie `name`

        Opens a changeset, stores `newts` as the first snapshot,
        creates and registers the serie table, records the
        changeset/snapshot row in it and links the serie to the
        changeset.

        Returns `newts`, or None when it is empty.
        """
        if not len(newts):
            return None

        snapshot = Snapshot(cn, self, name)
        csid = self._newchangeset(cn, author, insertion_date, metadata)
        head = snapshot.create(newts)
        row = {'cset': csid, 'snapshot': head}

        table = self._make_ts_table(cn, name, newts)
        cn.execute(table.insert().values(row))
        self._finalize_insertion(cn, csid, name)

        L.info('first insertion of %s (size=%s) by %s',
               name, len(newts), author or self._author)
        return newts

347
348
    def _update(self, cn, table, newts, name, author,
                metadata=None, insertion_date=None):
        """Insert a new revision of the already known serie `name`

        `newts` is validated against the stored metadata, then
        diffed (with `self.diff`) against the latest snapshot taken
        over the value date range of `newts`. When the diff is empty
        nothing is written and None is returned. Otherwise a
        changeset is opened, the snapshot updated, the serialized
        diff stored in `table` and the serie linked to the changeset.

        Returns the diff.
        """
        self._validate(cn, newts, name)
        snapshot = Snapshot(cn, self, name)

        index = newts.index
        current = snapshot.last(index.min(), index.max())
        diff = self.diff(current, newts)

        if len(diff) == 0:
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author or self._author, len(newts))
            return None

        csid = self._newchangeset(cn, author, insertion_date, metadata)
        head = snapshot.update(diff)
        row = {
            'cset': csid,
            'diff': self._serialize(diff),
            'snapshot': head
        }
        cn.execute(table.insert().values(row))
        self._finalize_insertion(cn, csid, name)

        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author or self._author)
        return diff

373
374
    # ts serialisation

375
376
377
378
379
    def _ensure_tz_consistency(self, cn, ts):
        """Return timeserie with tz aware index or not depending on metadata
        tzaware.
        """
        assert ts.name is not None
380
        metadata = self.metadata(cn, ts.name)
381
382
383
384
        if metadata and metadata.get('tzaware', False):
            return ts.tz_localize('UTC')
        return ts

385
    # serie table handling
386

387
388
    def _ts_table_name(self, seriename):
        # namespace.seriename
389
        return '{}.timeserie.{}'.format(self.namespace, seriename)
390

391
    def _table_definition_for(self, seriename):
        """Return the sqlalchemy Table describing the serie table

        Definitions are built once and kept in the module-level
        `TABLES` dict, keyed by the name given by `_ts_table_name`.
        Nothing is created in the database here.
        """
        tablename = self._ts_table_name(seriename)
        known = TABLES.get(tablename)
        if known is not None:
            return known

        cset_fk = ForeignKey('{}.changeset.id'.format(self.namespace))
        snapshot_fk = ForeignKey('{}.snapshot.{}.id'.format(
            self.namespace,
            seriename))
        table = Table(
            seriename, self.schema.meta,
            Column('id', Integer, primary_key=True),
            Column('cset', Integer, cset_fk,
                   index=True, nullable=False),
            Column('diff', BYTEA),
            Column('snapshot', Integer, snapshot_fk,
                   index=True),
            schema='{}.timeserie'.format(self.namespace),
            extend_existing=True
        )
        TABLES[tablename] = table
        return table
411

412
    def _make_ts_table(self, cn, name, ts):
        """Create the table of serie `name` and register it

        The registry row holds the serie name, its table name and a
        metadata dict computed from `ts`: 'tzaware', 'index_type',
        'index_names' (the non-empty index names) and 'value_type'.

        Returns the table.
        """
        tablename = self._ts_table_name(name)
        table = self._table_definition_for(name)
        table.create(cn)

        index = ts.index
        meta = {
            'tzaware': tzaware_serie(ts),
            'index_type': index.dtype.name,
            'index_names': [iname for iname in index.names if iname],
            'value_type': ts.dtypes.name
        }
        registration = self.schema.registry.insert().values(
            name=name,
            table_name=tablename,
            metadata=meta,
        )
        cn.execute(registration)
        return table

431
    def _get_ts_table(self, cn, name):
        """Return the table definition of serie `name` when the
        registry has a row for its table name, None otherwise.
        """
        reg = self.schema.registry
        lookup = reg.select().where(
            reg.c.table_name == self._ts_table_name(name)
        )
        if not cn.execute(lookup).scalar():
            return None
        return self._table_definition_for(name)
438

439
440
    # changeset handling

441
    def _newchangeset(self, cn, author, insertion_date=None, metadata=None):
        """Insert a changeset row and return its id

        author: stored as-is in the `author` column
        insertion_date: optional timezone aware datetime (asserted);
        defaults to the current time
        metadata: stored as-is in the `metadata` column

        The insertion date is always stored as an UTC timestamp.
        """
        table = self.schema.changeset
        if insertion_date is None:
            idate = pd.Timestamp.now(tz='UTC')
        else:
            assert insertion_date.tzinfo is not None
            # convert rather than pass `tz` to the constructor, which
            # recent pandas refuses for an already aware datetime
            idate = pd.Timestamp(insertion_date).tz_convert('UTC')
        sql = table.insert().values(
            author=author,
            metadata=metadata,
            insertion_date=idate)
        return cn.execute(sql).inserted_primary_key[0]
451

452
    def _changeset_series(self, cn, csid):
        """Return the list of serie names linked to changeset `csid`"""
        links = self.schema.changeset_series
        query = select([links.c.serie]).where(links.c.cset == csid)
        rows = cn.execute(query).fetchall()
        return [row[0] for row in rows]
458
459
460

    # insertion handling

461
    def _validate(self, cn, ts, name):
462
463
        if ts.isnull().all():
            # ts erasure
464
            return
465
        tstype = ts.dtype
466
        meta = self.metadata(cn, name)
467
        if tstype != meta['value_type']:
468
            m = 'Type error when inserting {}, new type is {}, type in base is {}'.format(
469
                name, tstype, meta['value_type'])
470
            raise Exception(m)
471
        if ts.index.dtype.name != meta['index_type']:
472
            raise Exception('Incompatible index types')
473

474
475
476
477
478
479
480
    def _finalize_insertion(self, cn, csid, name):
        """Link serie `name` to changeset `csid` in changeset_series"""
        links = self.schema.changeset_series
        cn.execute(links.insert().values(cset=csid, serie=name))
481

482
    def diff_at(self, cn, csetid, name):
        """Return what changeset `csetid` brought to serie `name`

        When the matching row of the serie table has id 1, the first
        snapshot is returned. Otherwise the stored diff is
        deserialized and passed through `_ensure_tz_consistency`.
        """
        table = self._get_ts_table(cn, name)
        cset = self.schema.changeset

        def at_changeset(column):
            # select one column of the serie row attached to `csetid`
            return select([column]).where(
                table.c.cset == cset.c.id
            ).where(
                cset.c.id == csetid
            )

        rowid = cn.execute(at_changeset(table.c.id)).scalar()
        if rowid == 1:
            return Snapshot(cn, self, name).first

        rawdiff = cn.execute(at_changeset(table.c.diff)).scalar()
        return self._ensure_tz_consistency(
            cn, self._deserialize(rawdiff, name)
        )