from datetime import datetime
from contextlib import contextmanager
import logging
import zlib
import math

import pandas as pd
import numpy as np

from sqlalchemy import Table, Column, Integer, ForeignKey
from sqlalchemy.sql.expression import select, func, desc
from sqlalchemy.dialects.postgresql import BYTEA

from tshistory.schema import SCHEMAS


L = logging.getLogger('tshistory.tsio')


def tojson(ts):
    if not isinstance(ts.index, pd.MultiIndex):
        return ts.to_json(date_format='iso',
                          double_precision=-int(math.log10(TimeSerie._precision)))

    # multi index case
    return ts.to_frame().reset_index().to_json(date_format='iso')


def num2float(pdobj):
    # get a Series or a Dataframe column
    if str(pdobj.dtype).startswith('int'):
        return pdobj.astype('float64')
    return pdobj


def fromjson(jsonb, tsname):
    return _fromjson(jsonb, tsname).fillna(value=np.nan)


def _fromjson(jsonb, tsname):
    if jsonb == '{}':
        return pd.Series(name=tsname)

    result = pd.read_json(jsonb, typ='series', dtype=False)
    if isinstance(result.index, pd.DatetimeIndex):
        result = num2float(result)
        return result

    # multi index case
    columns = result.index.values.tolist()
    columns.remove(tsname)
    result = pd.read_json(jsonb, typ='frame',
                          convert_dates=columns)
    result.set_index(sorted(columns), inplace=True)
    return num2float(result.iloc[:, 0])  # get a Series object
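

# Illustrative round-trip sketch (not part of the original module; the
# series below is made up): tojson/fromjson serialize a series to JSON
# and back, recovering the values up to the stored precision.
#
#   ts = pd.Series([1.0, 2.0],
#                  index=pd.date_range('2017-1-1', periods=2, freq='D'))
#   ts2 = fromjson(tojson(ts), 'ts')   # ~ts, up to TimeSerie._precision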


def subset(ts, fromdate, todate):
    if fromdate is None and todate is None:
        return ts
    return ts.loc[fromdate:todate]


class TimeSerie(object):
    _csid = None
    _snapshot_interval = 10
    _precision = 1e-14
    namespace = 'tsh'
    schema = None

    def __init__(self, namespace='tsh'):
        self.namespace = namespace
        self.schema = SCHEMAS[namespace]

    # API : changeset, insert, get, delete
    @contextmanager
    def newchangeset(self, cn, author, _insertion_date=None):
        """A context manager allowing insertion of several series within the
        same changeset identifier.

        This allows grouping changes to several series, hence
        producing a macro-change.

        _insertion_date is *only* provided for migration purposes and
        not part of the API.
        """
        assert self._csid is None
        self._csid = self._newchangeset(cn, author, _insertion_date)
        self._author = author
        yield
        del self._csid
        del self._author
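
    # Usage sketch (illustrative; `engine` and the series objects are
    # assumptions made up for the example):
    #
    #   tsh = TimeSerie()
    #   with engine.connect() as cn:
    #       with tsh.newchangeset(cn, 'babar'):
    #           tsh.insert(cn, series_a, 'ts_a')
    #           tsh.insert(cn, series_b, 'ts_b')
    #   # both series now share a single changeset id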

    def insert(self, cn, newts, name, author=None,
               extra_scalars={}):
        """Create a new revision of a given time series

        newts: pandas.Series with date index

        name: str unique identifier of the serie

        author: str free-form author name (mandatory, unless provided
        to the newchangeset context manager).

        """
        assert self._csid or author, 'author is mandatory'
        if self._csid and author:
            L.info('author {} will not be used when in a changeset'.format(author))
            author = None
        assert isinstance(newts, pd.Series)
        assert not newts.index.duplicated().any()

        newts = num2float(newts)

        if not len(newts):
            return

        assert ('<M8[ns]' == newts.index.dtype or
                'datetime' in str(newts.index.dtype) or
                isinstance(newts.index, pd.MultiIndex))

        newts.name = name
        table = self._get_ts_table(cn, name)

        if isinstance(newts.index, pd.MultiIndex):
            # we impose an order to survive roundtrips
            newts = newts.reorder_levels(sorted(newts.index.names))

        if table is None:
            # initial insertion
            if newts.isnull().all():
                return None
            newts = newts[~newts.isnull()]
            table = self._make_ts_table(cn, name)
            csid = self._csid or self._newchangeset(cn, author)
            value = {
                'csid': csid,
                'snapshot': self._serialize(newts),
            }
            # callback for extenders
            self._complete_insertion_value(value, extra_scalars)
            cn.execute(table.insert().values(value))
            self._finalize_insertion(cn, csid, name)
            L.info('first insertion of %s (size=%s) by %s',
                   name, len(newts), author or self._author)
            return newts

        diff, newsnapshot = self._compute_diff_and_newsnapshot(
            cn, table, newts, **extra_scalars
        )
        if diff is None:
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author or self._author, len(newts))
            return

        tip_id = self._get_tip_id(cn, table)
        csid = self._csid or self._newchangeset(cn, author)
        value = {
            'csid': csid,
            'diff': self._serialize(diff),
            'snapshot': self._serialize(newsnapshot),
            'parent': tip_id,
        }
        # callback for extenders
        self._complete_insertion_value(value, extra_scalars)
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)

        if tip_id > 1 and tip_id % self._snapshot_interval:
            self._purge_snapshot_at(cn, table, tip_id)
        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author or self._author)
        return diff
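
    # Revisioning sketch (illustrative; `tsh` and `cn` as in the sketch
    # above): the first insert stores a full snapshot, the second stores
    # only the changed point as a diff.
    #
    #   ts = pd.Series([1., 2., 3.],
    #                  index=pd.date_range('2017-1-1', periods=3, freq='D'))
    #   tsh.insert(cn, ts, 'demo', 'babar')    # -> returns the full series
    #   ts.iloc[1] = 2.5
    #   tsh.insert(cn, ts, 'demo', 'babar')    # -> returns a 1-point diff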

    def get(self, cn, name, revision_date=None,
            from_value_date=None, to_value_date=None):
        """Compute and return the serie of a given name

        revision_date: datetime filter to get previous versions of the
        serie

        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        qfilter = []
        if revision_date:
            qfilter.append(lambda cset, _: cset.c.insertion_date <= revision_date)
        current = self._build_snapshot_upto(cn, table, qfilter,
                                            from_value_date=from_value_date,
                                            to_value_date=to_value_date)

        if current is not None:
            current.name = name
            current = current[~current.isnull()]
        return current
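
    # Read sketch (illustrative; the name and date are assumptions):
    #
    #   tsh.get(cn, 'demo')                          # latest state
    #   tsh.get(cn, 'demo',
    #           revision_date=datetime(2017, 1, 1))  # state as of that date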

    def get_group(self, cn, name, revision_date=None):
        csid = self._latest_csid_for(cn, name)

        group = {}
        for seriename in self._changeset_series(cn, csid):
            serie = self.get(cn, seriename, revision_date)
            if serie is not None:
                group[seriename] = serie
        return group

    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None,
                    from_value_date=None,
                    to_value_date=None,
                    diffmode=False):
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        # compute diffs above the snapshot
        cset = self.schema.changeset
        diffsql = select([cset.c.id, cset.c.insertion_date, table.c.diff]
        ).order_by(cset.c.id
        ).where(table.c.csid == cset.c.id)

        if from_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date >= from_insertion_date)
        if to_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date <= to_insertion_date)

        diffs = cn.execute(diffsql).fetchall()
        if not diffs:
            # it's fine to ask for an insertion date range
            # where nothing happened, but you get nothing
            return

        if diffmode:
            series = []
            for csid, revdate, diff in diffs:
                if diff is None:  # we must fetch the initial snapshot
                    sql = select([table.c.snapshot]).where(table.c.csid == csid)
                    diff = cn.execute(sql).scalar()
                serie = subset(self._deserialize(diff, name), from_value_date, to_value_date)
                mindex = [(revdate, valuestamp) for valuestamp in serie.index]
                serie.index = pd.MultiIndex.from_tuples(mindex, names=[
                    'insertion_date', 'value_date']
                )
                series.append(serie)
            series = pd.concat(series)
            series.name = name
            return series

        csid, revdate, diff_ = diffs[0]
        snapshot = self._build_snapshot_upto(cn, table, [
            lambda cset, _: cset.c.id <= csid
        ], from_value_date, to_value_date)

        series = [(revdate, subset(snapshot, from_value_date, to_value_date))]
        for csid_, revdate, diff in diffs[1:]:
            diff = subset(self._deserialize(diff, table.name),
                          from_value_date, to_value_date)

            serie = self._apply_diff(series[-1][1], diff)
            series.append((revdate, serie))

        for revdate, serie in series:
            mindex = [(revdate, valuestamp) for valuestamp in serie.index]
            serie.index = pd.MultiIndex.from_tuples(mindex, names=[
                'insertion_date', 'value_date']
            )

        serie = pd.concat([serie for revdate_, serie in series])
        serie.name = name
        return serie
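
    # Result shape sketch (illustrative; names/dates are assumptions): the
    # returned series carries an (insertion_date, value_date) MultiIndex,
    # one reconstructed state per changeset.
    #
    #   hist = tsh.get_history(cn, 'demo')
    #   hist.loc[datetime(2017, 1, 2)]   # 'demo' as known after that insertion
    #
    # With diffmode=True each insertion contributes only its diff rather
    # than the fully reconstructed series.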

    def exists(self, cn, name):
        return self._get_ts_table(cn, name) is not None

    def latest_insertion_date(self, cn, name):
        cset = self.schema.changeset
        tstable = self._get_ts_table(cn, name)
        sql = select([func.max(cset.c.insertion_date)]
        ).where(tstable.c.csid == cset.c.id)
        return cn.execute(sql).scalar()

    def changeset_at(self, cn, seriename, revdate, mode='strict'):
        assert mode in ('strict', 'before', 'after')
        cset = self.schema.changeset
        table = self._table_definition_for(seriename)
        sql = select([table.c.csid]).where(
            table.c.csid == cset.c.id
        )
        if mode == 'strict':
            sql = sql.where(cset.c.insertion_date == revdate)
        elif mode == 'before':
            sql = sql.where(cset.c.insertion_date <= revdate)
        else:
            sql = sql.where(cset.c.insertion_date >= revdate)
        return cn.execute(sql).scalar()

    def strip(self, cn, seriename, csid):
        logs = self.log(cn, fromrev=csid, names=(seriename,))
        assert logs

        # put stripping info in the metadata
        cset = self.schema.changeset
        cset_serie = self.schema.changeset_series
        for log in logs:
            # update changeset.metadata
            metadata = cn.execute(
                select([cset.c.metadata]).where(cset.c.id == log['rev'])
            ).scalar() or {}
            metadata['tshistory.info'] = 'got stripped from {}'.format(csid)
            sql = cset.update().where(cset.c.id == log['rev']
            ).values(metadata=metadata)
            cn.execute(sql)
            # delete changeset_series item
            sql = cset_serie.delete().where(
                cset_serie.c.csid == log['rev']
            ).where(
                cset_serie.c.serie == seriename
            )
            cn.execute(sql)

        # wipe the diffs
        table = self._table_definition_for(seriename)
        cn.execute(table.delete().where(table.c.csid == csid))
        # rebuild the top-level snapshot
        cstip = self._latest_csid_for(cn, seriename)
        if cn.execute(select([table.c.snapshot]).where(table.c.csid == cstip)).scalar() is None:
            snap = self._build_snapshot_upto(
                cn, table,
                qfilter=(lambda cset, _t: cset.c.id < csid,)
            )
            sql = table.update().where(
                table.c.csid == cstip
            ).values(
                snapshot=self._serialize(snap)
            )
            cn.execute(sql)

    def info(self, cn):
        """Gather global statistics on the current tshistory repository
        """
        sql = 'select count(*) from {}.registry'.format(self.namespace)
        stats = {'series count': cn.execute(sql).scalar()}
        sql = 'select max(id) from {}.changeset'.format(self.namespace)
        stats['changeset count'] = cn.execute(sql).scalar()
        sql = 'select distinct name from {}.registry order by name'.format(self.namespace)
        stats['serie names'] = [row for row, in cn.execute(sql).fetchall()]
        return stats
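
    # The returned mapping looks like this (values are made-up examples):
    #
    #   {'series count': 2,
    #    'changeset count': 10,
    #    'serie names': ['ts_a', 'ts_b']}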

    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            stripped=False,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.
        """
        log = []
        cset, cset_series, reg = (
            self.schema.changeset,
            self.schema.changeset_series,
            self.schema.registry
        )

        sql = select([cset.c.id, cset.c.author, cset.c.insertion_date, cset.c.metadata]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)

        if names:
            sql = sql.where(reg.c.name.in_(names))

        if authors:
            sql = sql.where(cset.c.author.in_(authors))

        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)

        if torev:
            sql = sql.where(cset.c.id <= torev)

        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)

        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        if stripped:
            # outerjoin to show dead things
            sql = sql.select_from(cset.outerjoin(cset_series))
        else:
            sql = sql.where(cset.c.id == cset_series.c.csid
            ).where(cset_series.c.serie == reg.c.name)

        rset = cn.execute(sql)
        for csetid, author, revdate, meta in rset.fetchall():
            log.append({'rev': csetid, 'author': author, 'date': revdate,
                        'meta': meta or {},
                        'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in log:
                rev['diff'] = {name: self._diff(cn, rev['rev'], name)
                               for name in rev['names']}

        log.sort(key=lambda rev: rev['rev'])
        return log
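
    # Each log entry is a dict of this shape (values are made-up examples):
    #
    #   {'rev': 1, 'author': 'babar',
    #    'date': datetime(2017, 1, 1, 10, 0),
    #    'meta': {},
    #    'names': ['ts_a']}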

    # /API
    # Helpers

    # ts serialisation

    def _serialize(self, ts):
        if ts is None:
            return None
        return zlib.compress(tojson(ts).encode('utf-8'))

    def _deserialize(self, ts, name):
        return fromjson(zlib.decompress(ts).decode('utf-8'), name)

    # serie table handling

    def _ts_table_name(self, seriename):
        # namespace.seriename
        return '{}.timeserie.{}'.format(self.namespace, seriename)

    def _table_definition_for(self, seriename):
        return Table(
            seriename, self.schema.meta,
            Column('id', Integer, primary_key=True),
            Column('csid', Integer,
                   ForeignKey('{}.changeset.id'.format(self.namespace)),
                   index=True, nullable=False),
            # constraint: there is either .diff or .snapshot
            Column('diff', BYTEA),
            Column('snapshot', BYTEA),
            Column('parent',
                   Integer,
                   ForeignKey('{}.timeserie.{}.id'.format(self.namespace,
                                                          seriename),
                              ondelete='cascade'),
                   nullable=True,
                   unique=True,
                   index=True),
            schema='{}.timeserie'.format(self.namespace),
            extend_existing=True
        )

    def _make_ts_table(self, cn, name):
        tablename = self._ts_table_name(name)
        table = self._table_definition_for(name)
        table.create(cn)
        sql = self.schema.registry.insert().values(
            name=name,
            table_name=tablename)
        cn.execute(sql)
        return table

    def _get_ts_table(self, cn, name):
        reg = self.schema.registry
        tablename = self._ts_table_name(name)
        sql = reg.select().where(reg.c.table_name == tablename)
        tid = cn.execute(sql).scalar()
        if tid:
            return self._table_definition_for(name)

    # changeset handling

    def _newchangeset(self, cn, author, _insertion_date=None):
        table = self.schema.changeset
        sql = table.insert().values(
            author=author,
            insertion_date=_insertion_date or datetime.now())
        return cn.execute(sql).inserted_primary_key[0]

    def _latest_csid_for(self, cn, name):
        table = self._get_ts_table(cn, name)
        sql = select([func.max(table.c.csid)])
        return cn.execute(sql).scalar()

    def _changeset_series(self, cn, csid):
        cset_serie = self.schema.changeset_series
        sql = select([cset_serie.c.serie]
        ).where(cset_serie.c.csid == csid)

        return [seriename for seriename, in cn.execute(sql).fetchall()]

    # insertion handling

    def _get_tip_id(self, cn, table):
        " get the *local* id "
        sql = select([func.max(table.c.id)])
        return cn.execute(sql).scalar()

    def _complete_insertion_value(self, value, extra_scalars):
        pass

    def _finalize_insertion(self, cn, csid, name):
        table = self.schema.changeset_series
        sql = table.insert().values(
            csid=csid,
            serie=name
        )
        cn.execute(sql)

    # snapshot handling

    def _purge_snapshot_at(self, cn, table, diffid):
        cn.execute(
            table.update(
            ).where(table.c.id == diffid
            ).values(snapshot=None)
        )

    def _validate_type(self, oldts, newts, name):
        if (oldts is None or
            oldts.isnull().all() or
            newts.isnull().all()):
            return
        old_type = oldts.dtype
        new_type = newts.dtype
        if new_type != old_type:
            m = 'Type error when inserting {}, new type is {}, type in base is {}'.format(
                name, new_type, old_type)
            raise Exception(m)

    def _compute_diff_and_newsnapshot(self, cn, table, newts, **extra_scalars):
        snapshot = self._build_snapshot_upto(cn, table)
        self._validate_type(snapshot, newts, table.name)
        diff = self._compute_diff(snapshot, newts)

        if len(diff) == 0:
            return None, None

        # full state computation & insertion
        newsnapshot = self._apply_diff(snapshot, diff)
        return diff, newsnapshot

    def _find_snapshot(self, cn, table, qfilter=(), column='snapshot',
                       from_value_date=None, to_value_date=None):
        cset = self.schema.changeset
        sql = select([table.c.id, table.c[column]]
        ).order_by(desc(table.c.id)
        ).limit(1
        ).where(table.c[column] != None
        ).select_from(table.join(cset))

        if qfilter:
            sql = sql.where(table.c.csid <= cset.c.id)
            for filtercb in qfilter:
                sql = sql.where(filtercb(cset, table))

        try:
            snapid, snapdata = cn.execute(sql).fetchone()
            snapdata = subset(self._deserialize(snapdata, table.name),
                              from_value_date, to_value_date)
        except TypeError:
            return None, None
        return snapid, snapdata

    def _build_snapshot_upto(self, cn, table, qfilter=(),
                             from_value_date=None, to_value_date=None):
        snapid, snapshot = self._find_snapshot(cn, table, qfilter,
                                               from_value_date=from_value_date,
                                               to_value_date=to_value_date)
        if snapid is None:
            return None

        cset = self.schema.changeset
        # beware the potential cartesian product
        # between table & cset if there is no qfilter
        sql = select([table.c.id,
                      table.c.diff,
                      table.c.parent,
                      cset.c.insertion_date]
        ).order_by(table.c.id
        ).where(table.c.id > snapid)

        if qfilter:
            sql = sql.where(table.c.csid == cset.c.id)
            for filtercb in qfilter:
                sql = sql.where(filtercb(cset, table))

        alldiffs = pd.read_sql(sql, cn)

        if len(alldiffs) == 0:
            return snapshot

        # initial ts
        ts = self._deserialize(alldiffs.loc[0, 'diff'], table.name)
        for row in alldiffs.loc[1:].itertuples():
            diff = subset(self._deserialize(row.diff, table.name),
                          from_value_date, to_value_date)
            ts = self._apply_diff(ts, diff)
        ts = self._apply_diff(snapshot, ts)
        assert ts.index.dtype.name == 'datetime64[ns]' or len(ts) == 0
        return ts

    # diff handling

    def _diff(self, cn, csetid, name):
        table = self._get_ts_table(cn, name)
        cset = self.schema.changeset

        def filtercset(sql):
            return sql.where(table.c.csid == cset.c.id
            ).where(cset.c.id == csetid)

        sql = filtercset(select([table.c.id]))
        tsid = cn.execute(sql).scalar()

        if tsid == 1:
            sql = select([table.c.snapshot])
        else:
            sql = select([table.c.diff])
        sql = filtercset(sql)

        return self._deserialize(cn.execute(sql).scalar(), name)

    def _compute_diff(self, fromts, tots):
        """Compute the difference between fromts and tots
        (like in tots - fromts).

        """
        if fromts is None:
            return tots
        fromts = fromts[~fromts.isnull()]
        if not len(fromts):
            return tots

        mask_overlap = tots.index.isin(fromts.index)
        fromts_overlap = fromts[tots.index[mask_overlap]]
        tots_overlap = tots[mask_overlap]

        if fromts.dtype == 'float64':
            mask_equal = np.isclose(fromts_overlap, tots_overlap,
                                    rtol=0, atol=self._precision)
        else:
            mask_equal = fromts_overlap == tots_overlap

        mask_na_equal = fromts_overlap.isnull() & tots_overlap.isnull()
        mask_equal = mask_equal | mask_na_equal

        diff_overlap = tots[mask_overlap][~mask_equal]
        diff_new = tots[~mask_overlap]
        diff_new = diff_new[~diff_new.isnull()]
        return pd.concat([diff_overlap, diff_new])
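
    # Worked example (illustrative): with fromts = {d1: 1.0, d2: 2.0} and
    # tots = {d2: 2.5, d3: 3.0}, the diff is {d2: 2.5, d3: 3.0} -- d2
    # because its value moved by more than _precision, d3 because it is
    # new. Points absent from tots (here d1) do not appear in the diff.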

    def _apply_diff(self, base_ts, new_ts):
        """Produce a new ts using base_ts as a base and taking any
        intersecting and new values from new_ts.

        """
        if base_ts is None:
            return new_ts
        if new_ts is None:
            return base_ts
        result_ts = pd.Series([0.0], index=base_ts.index.union(new_ts.index))
        result_ts[base_ts.index] = base_ts
        result_ts[new_ts.index] = new_ts
        result_ts.sort_index(inplace=True)
        result_ts.name = base_ts.name
        return result_ts
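
    # Worked example (illustrative): applying new_ts = {d2: 2.5, d3: 3.0}
    # on base_ts = {d1: 1.0, d2: 2.0} yields {d1: 1.0, d2: 2.5, d3: 3.0}:
    # overlapping points are overridden, new points are added.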