tsio.py 18.1 KB
Newer Older
1
from datetime import datetime
2
from contextlib import contextmanager
3
import logging
4
5
6

import pandas as pd

7
from sqlalchemy import Table, Column, Integer, ForeignKey
Aurélien Campéas's avatar
Aurélien Campéas committed
8
from sqlalchemy.sql.expression import select, func, desc
9
from sqlalchemy.dialects.postgresql import BYTEA
10

11
from tshistory.schema import tsschema
12
13
from tshistory.util import (
    inject_in_index,
14
15
    mindate,
    maxdate,
16
    num2float,
17
    subset,
18
    SeriesServices,
19
20
    tzaware_serie
)
21
from tshistory.snapshot import Snapshot
22

23
L = logging.getLogger('tshistory.tsio')
24
TABLES = {}
25
26


27
class TimeSerie(SeriesServices):
    # changeset id installed by the `newchangeset` context manager;
    # None outside of such a context
    _csid = None
    # postgres schema namespace under which all tables live
    namespace = 'tsh'
    schema = None

    def __init__(self, namespace='tsh'):
        """Bind this handler to a database namespace and define the
        associated schema objects (registry, changeset, ... tables).
        """
        self.namespace = namespace
        self.schema = tsschema(namespace)
        self.schema.define()
        # per-instance cache: serie name -> metadata dict (see `metadata`)
        self.metadatacache = {}

    # API : changeset, insert, get, delete
    @contextmanager
    def newchangeset(self, cn, author, _insertion_date=None):
        """A context manager to allow insertion of several series within the
        same changeset identifier

        This allows to group changes to several series, hence
        producing a macro-change.

        _insertion_date is *only* provided for migration purposes and
        not part of the API.
        """
        # nested changesets are not supported
        assert self._csid is None
        self._csid = self._newchangeset(cn, author, _insertion_date)
        self._author = author
        yield
        # NOTE(review): if the managed block raises, _csid/_author are
        # left dangling on the instance (no try/finally) — confirm this
        # is intentional before relying on reuse after a failure
        del self._csid
        del self._author

57
    def insert(self, cn, newts, name, author=None, _insertion_date=None):
        """Create a new revision of a given time series

        newts: pandas.Series with date index

        name: str unique identifier of the serie

        author: str free-form author name (mandatory, unless provided
        to the newchangeset context manager).

        Returns the created serie on first insertion, the computed diff
        on update, or None when there is nothing to record.
        """
        assert self._csid or author, 'author is mandatory'
        if self._csid and author:
            # fix: message used to read 'author r{}' (stray 'r' typo);
            # also use lazy %-style logging args instead of .format
            L.info('author %s will not be used when in a changeset', author)
            author = None
        assert isinstance(newts, pd.Series)
        assert not newts.index.duplicated().any()

        # normalize integer-ish values to floats for storage uniformity
        newts = num2float(newts)

        if not len(newts):
            return

        # only datetime-like (or multi) indexes are accepted
        assert ('<M8[ns]' == newts.index.dtype or
                'datetime' in str(newts.index.dtype) or
                isinstance(newts.index, pd.MultiIndex))

        newts.name = name
        table = self._get_ts_table(cn, name)

        if isinstance(newts.index, pd.MultiIndex):
            # we impose an order to survive roundtrips
            newts = newts.reorder_levels(sorted(newts.index.names))

        if table is None:
            # serie does not exist yet: first insertion
            return self._create(cn, newts, name, author, _insertion_date)

        return self._update(cn, table, newts, name, author, _insertion_date)
    def get(self, cn, name, revision_date=None,
            from_value_date=None, to_value_date=None):
98
99
100
101
102
        """Compute and return the serie of a given name

        revision_date: datetime filter to get previous versions of the
        serie

103
        """
104
        table = self._get_ts_table(cn, name)
105
106
        if table is None:
            return
107

108
109
110
        qfilter = []
        if revision_date:
            qfilter.append(lambda cset, _: cset.c.insertion_date <= revision_date)
111
112
113
114
        snap = Snapshot(cn, self, name)
        current = snap.build_upto(qfilter,
                                  from_value_date=from_value_date,
                                  to_value_date=to_value_date)
115

116
117
        if current is not None:
            current.name = name
118
            current = current[~current.isnull()]
119
        return current
120

121
    def metadata(self, cn, tsname):
        """Return metadata dict of timeserie."""
        # serve from the per-instance cache when possible
        if tsname in self.metadatacache:
            return self.metadatacache[tsname]
        registry = self.schema.registry
        query = select(
            [registry.c.metadata]
        ).where(registry.c.name == tsname)
        meta = cn.execute(query).scalar()
        self.metadatacache[tsname] = meta
        return meta
    def get_group(self, cn, name, revision_date=None):
        """Return a dict mapping serie name -> serie for every serie
        belonging to the latest changeset of `name`.
        """
        csid = self._latest_csid_for(cn, name)
        fetched = (
            (seriename, self.get(cn, seriename, revision_date))
            for seriename in self._changeset_series(cn, csid)
        )
        return {
            seriename: serie
            for seriename, serie in fetched
            if serie is not None
        }
    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None,
                    from_value_date=None,
                    to_value_date=None,
                    diffmode=False):
        """Return the revision history of serie `name` as a single serie
        whose index is augmented with the insertion date of each revision.

        In `diffmode`, only the successive diffs are returned (plus the
        initial snapshot); otherwise each revision is fully reconstructed.
        Returns None for an unknown serie or an empty insertion range.
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        # compute diffs above the snapshot
        cset = self.schema.changeset
        diffsql = select([cset.c.id, cset.c.insertion_date, table.c.diff]
        ).order_by(cset.c.id
        ).where(table.c.cset == cset.c.id)

        if from_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date >= from_insertion_date)
        if to_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date <= to_insertion_date)

        diffs = cn.execute(diffsql).fetchall()
        if not diffs:
            # it's fine to ask for an insertion date range
            # where nothing did happen, but you get nothing
            return

        if diffmode:
            snapshot = Snapshot(cn, self, name)
            series = []
            for csid, revdate, diff in diffs:
                if diff is None:  # we must fetch the initial snapshot
                    serie = subset(snapshot.first, from_value_date, to_value_date)
                else:
                    serie = subset(self._deserialize(diff, name), from_value_date, to_value_date)
                    serie = self._ensure_tz_consistency(cn, serie)
                # stamp each point with its revision (insertion) date
                inject_in_index(serie, revdate)
                series.append(serie)
            series = pd.concat(series)
            series.name = name
            return series

        # full-reconstruction mode: build the first revision from the
        # snapshot, then patch forward with each subsequent diff
        csid, revdate, diff_ = diffs[0]
        snap = Snapshot(cn, self, name)
        snapshot = snap.build_upto([lambda cset, _: cset.c.id <= csid],
                                   from_value_date, to_value_date)

        series = [(revdate, subset(snapshot, from_value_date, to_value_date))]
        for csid_, revdate, diff in diffs[1:]:
            diff = subset(self._deserialize(diff, table.name),
                          from_value_date, to_value_date)
            diff = self._ensure_tz_consistency(cn, diff)

            # each revision = previous revision patched with its diff
            serie = self.patch(series[-1][1], diff)
            series.append((revdate, serie))

        for revdate, serie in series:
            inject_in_index(serie, revdate)

        serie = pd.concat([serie for revdate_, serie in series])
        serie.name = name
        return serie
    def exists(self, cn, name):
        """Predicate: is a serie of that name registered ?"""
        table = self._get_ts_table(cn, name)
        return table is not None
    def latest_insertion_date(self, cn, name):
        """Return the insertion date of the most recent changeset
        touching serie `name`.
        """
        changeset = self.schema.changeset
        table = self._get_ts_table(cn, name)
        query = select(
            [func.max(changeset.c.insertion_date)]
        ).where(table.c.cset == changeset.c.id)
        return cn.execute(query).scalar()
    def changeset_at(self, cn, seriename, revdate, mode='strict'):
        """Return the changeset id of `seriename` at `revdate`.

        `mode` selects how the insertion date is compared:
        'strict' (==), 'before' (<=) or 'after' (>=).
        """
        assert mode in ('strict', 'before', 'after')
        cset = self.schema.changeset
        table = self._table_definition_for(seriename)
        base = select([table.c.cset]).where(
            table.c.cset == cset.c.id
        )
        if mode == 'before':
            filtered = base.where(cset.c.insertion_date <= revdate)
        elif mode == 'after':
            filtered = base.where(cset.c.insertion_date >= revdate)
        else:  # strict
            filtered = base.where(cset.c.insertion_date == revdate)
        return cn.execute(filtered).scalar()
    def strip(self, cn, seriename, csid):
        """Remove, for serie `seriename`, everything from changeset `csid`
        onward: the changeset/serie links and the stored diffs.

        The changeset rows themselves are kept, annotated in their
        metadata with a stripping note.
        """
        logs = self.log(cn, fromrev=csid, names=(seriename,))
        assert logs

        # put stripping info in the metadata
        cset = self.schema.changeset
        cset_serie = self.schema.changeset_series
        for log in logs:
            # update changeset.metadata
            metadata = cn.execute(
                select([cset.c.metadata]).where(cset.c.id == log['rev'])
            ).scalar() or {}
            metadata['tshistory.info'] = 'got stripped from {}'.format(csid)
            sql = cset.update().where(cset.c.id == log['rev']
            ).values(metadata=metadata)
            cn.execute(sql)
            # delete changset_serie item
            sql = cset_serie.delete().where(
                cset_serie.c.cset == log['rev']
            ).where(
                cset_serie.c.serie == seriename
            )
            cn.execute(sql)

        # wipe the diffs
        table = self._table_definition_for(seriename)
        cn.execute(table.delete().where(table.c.cset >= csid))
    def info(self, cn):
        """Gather global statistics on the current tshistory repository
        """
        ns = self.namespace
        series_count = cn.execute(
            'select count(*) from {}.registry'.format(ns)
        ).scalar()
        changeset_count = cn.execute(
            'select max(id) from {}.changeset'.format(ns)
        ).scalar()
        names = [
            name for name, in cn.execute(
                'select distinct name from {}.registry order by name'.format(ns)
            ).fetchall()
        ]
        return {
            'series count': series_count,
            'changeset count': changeset_count,
            'serie names': names
        }
    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            stripped=False,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.

        Optional filters narrow by serie names, authors, changeset id
        range or insertion date range. `limit` keeps only the most
        recent changesets; `diff` attaches each changeset's diffs;
        `stripped` also shows changesets whose serie links were removed.
        """
        log = []
        cset, cset_series, reg = (
            self.schema.changeset,
            self.schema.changeset_series,
            self.schema.registry
        )

        # fetch newest-first so that `limit` keeps the latest changesets;
        # the result is re-sorted ascending at the end
        sql = select([cset.c.id, cset.c.author, cset.c.insertion_date, cset.c.metadata]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)

        if names:
            sql = sql.where(reg.c.name.in_(names))

        if authors:
            sql = sql.where(cset.c.author.in_(authors))

        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)

        if torev:
            sql = sql.where(cset.c.id <= torev)

        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)

        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        if stripped:
            # outerjoin to show dead things
            sql = sql.select_from(cset.outerjoin(cset_series))
        else:
            sql = sql.where(cset.c.id == cset_series.c.cset
            ).where(cset_series.c.serie == reg.c.name)

        rset = cn.execute(sql)
        for csetid, author, revdate, meta in rset.fetchall():
            log.append({'rev': csetid, 'author': author,
                        'date': pd.Timestamp(revdate, tz='utc'),
                        'meta': meta or {},
                        'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in log:
                rev['diff'] = {name: self.diff_at(cn, rev['rev'], name)
                               for name in rev['names']}

        # chronological (ascending changeset id) order for callers
        log.sort(key=lambda rev: rev['rev'])
        return log
330
331
    # /API
    # Helpers
332

333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
    # creation / update

    def _create(self, cn, newts, name, author, insertion_date=None):
        """First insertion of a serie: create its table, its initial
        snapshot and register it. Returns the inserted serie (nan points
        dropped) or None when nothing remains to insert.
        """
        # initial insertion
        newts = newts[~newts.isnull()]
        if len(newts) == 0:
            return None
        snapshot = Snapshot(cn, self, name)
        # reuse the ambient changeset when inside `newchangeset`
        csid = self._csid or self._newchangeset(cn, author, insertion_date)
        head = snapshot.create(newts)
        value = {
            'cset': csid,
            'snapshot': head
        }
        table = self._make_ts_table(cn, name, newts)
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)
        L.info('first insertion of %s (size=%s) by %s',
               name, len(newts), author or self._author)
        return newts
    def _update(self, cn, table, newts, name, author, insertion_date=None):
        """Subsequent insertion of a serie: validate the incoming data,
        compute the diff against the current state and store it.
        Returns the diff, or None when there is no actual change.
        """
        self._validate(cn, newts, name)
        snapshot = Snapshot(cn, self, name)
        # diff against the stored state restricted to newts' value range
        diff = self.diff(snapshot.last(mindate(newts), maxdate(newts)), newts)
        if not len(diff):
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author or self._author, len(newts))
            return

        # reuse the ambient changeset when inside `newchangeset`
        csid = self._csid or self._newchangeset(cn, author, insertion_date)
        head = snapshot.update(diff)
        value = {
            'cset': csid,
            'diff': self._serialize(diff),
            'snapshot': head
        }
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)

        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author or self._author)
        return diff
377
378
    # ts serialisation

379
380
381
382
383
    def _ensure_tz_consistency(self, cn, ts):
        """Return timeserie with tz aware index or not depending on metadata
        tzaware.
        """
        assert ts.name is not None
        metadata = self.metadata(cn, ts.name)
        if metadata and metadata.get('tzaware', False):
            if isinstance(ts.index, pd.MultiIndex):
                # NOTE(review): every level is tz-localized to UTC, which
                # assumes all levels are datetime-like — confirm for
                # multi-indexes mixing datetime and non-datetime levels
                for i in range(len(ts.index.levels)):
                    ts.index = ts.index.set_levels(
                        ts.index.levels[i].tz_localize('UTC'),
                        level=i)
                return ts
            return ts.tz_localize('UTC')
        return ts
395
    # serie table handling
396

397
398
    def _ts_table_name(self, seriename):
        """Return the qualified table name for a serie."""
        # layout: <namespace>.timeserie.<seriename>
        return '.'.join((self.namespace, 'timeserie', seriename))
    def _table_definition_for(self, seriename):
        """Return (building and caching it on first use) the sqlalchemy
        Table object holding the diffs/snapshot refs of `seriename`.
        """
        tablename = self._ts_table_name(seriename)
        table = TABLES.get(tablename)
        if table is None:
            # one row per changeset touching the serie
            TABLES[tablename] = table = Table(
                seriename, self.schema.meta,
                Column('id', Integer, primary_key=True),
                Column('cset', Integer,
                       ForeignKey('{}.changeset.id'.format(self.namespace)),
                       index=True, nullable=False),
                # serialized diff; NULL for the initial insertion
                Column('diff', BYTEA),
                # pointer into the per-serie snapshot table
                Column('snapshot', Integer,
                       ForeignKey('{}.snapshot.{}.id'.format(
                           self.namespace,
                           seriename)),
                       index=True),
                schema='{}.timeserie'.format(self.namespace),
                extend_existing=True
            )
        return table
422
    def _make_ts_table(self, cn, name, ts):
        """Create the diff table for a brand new serie and register it
        (with its index/value metadata) in the registry.
        """
        table = self._table_definition_for(name)
        table.create(cn)
        index = ts.index
        # keep only the named index levels
        index_names = [iname for iname in index.names if iname]
        registration = self.schema.registry.insert().values(
            name=name,
            table_name=self._ts_table_name(name),
            metadata={
                'tzaware': tzaware_serie(ts),
                'index_type': index.dtype.name,
                'index_names': index_names,
                'value_type': ts.dtypes.name
            },
        )
        cn.execute(registration)
        return table
441
    def _get_ts_table(self, cn, name):
        """Return the sqlalchemy table of serie `name`, or None when the
        serie is not registered.
        """
        registry = self.schema.registry
        query = registry.select().where(
            registry.c.table_name == self._ts_table_name(name)
        )
        if cn.execute(query).scalar():
            return self._table_definition_for(name)
        return None
449
450
    # changeset handling

451
    def _newchangeset(self, cn, author, _insertion_date=None):
        """Insert a new changeset row and return its id.

        `_insertion_date` (migration use only) must be timezone aware
        when provided; otherwise the current UTC time is used.
        """
        if _insertion_date is not None:
            assert _insertion_date.tzinfo is not None
        idate = pd.Timestamp(_insertion_date or datetime.utcnow(), tz='UTC')
        changeset = self.schema.changeset
        inserted = cn.execute(
            changeset.insert().values(
                author=author,
                insertion_date=idate
            )
        )
        return inserted.inserted_primary_key[0]
461
462
    def _latest_csid_for(self, cn, name):
        """Return the most recent changeset id for serie `name`."""
        table = self._get_ts_table(cn, name)
        query = select([func.max(table.c.cset)])
        return cn.execute(query).scalar()
466
    def _changeset_series(self, cn, csid):
        """List the names of all series touched by changeset `csid`."""
        linktable = self.schema.changeset_series
        query = select(
            [linktable.c.serie]
        ).where(linktable.c.cset == csid)
        rows = cn.execute(query).fetchall()
        return [seriename for seriename, in rows]

    # insertion handling

475
    def _validate(self, cn, ts, name):
        """Check that `ts` is compatible (value type, index type, index
        names) with the registered serie `name`; raise on mismatch.
        """
        if ts.isnull().all():
            # ts erasure
            return
        tstype = ts.dtype
        meta = self.metadata(cn, name)
        if tstype != meta['value_type']:
            m = 'Type error when inserting {}, new type is {}, type in base is {}'.format(
                name, tstype, meta['value_type'])
            raise Exception(m)
        if ts.index.dtype.name != meta['index_type']:
            raise Exception('Incompatible index types')
        # named levels only, in order, must match what was registered
        inames = [name for name in ts.index.names if name]
        if inames != meta['index_names']:
            raise Exception('Incompatible multi indexes: {} vs {}'.format(
                meta['index_names'], inames)
            )
493
494
495
496
497
498
499
    def _finalize_insertion(self, cn, csid, name):
        """Link changeset `csid` to serie `name` in the association table."""
        linktable = self.schema.changeset_series
        cn.execute(
            linktable.insert().values(
                cset=csid,
                serie=name
            )
        )
501
    def diff_at(self, cn, csetid, name):
        """Return the diff introduced by changeset `csetid` for serie
        `name`; for the very first insertion (row id 1, which stores no
        diff) the initial snapshot is returned instead.
        """
        table = self._get_ts_table(cn, name)
        cset = self.schema.changeset

        def filtercset(sql):
            # restrict to the requested changeset
            return sql.where(table.c.cset == cset.c.id
            ).where(cset.c.id == csetid)

        sql = filtercset(select([table.c.id]))
        tsid = cn.execute(sql).scalar()

        if tsid == 1:
            # first row of the serie: the whole serie is the snapshot
            return Snapshot(cn, self, name).first

        sql = filtercset(select([table.c.diff]))
        ts = self._deserialize(cn.execute(sql).scalar(), name)
        return self._ensure_tz_consistency(cn, ts)