tsio.py 18.2 KB
Newer Older
1
from datetime import datetime
2
from contextlib import contextmanager
3
import logging
4
5
6

import pandas as pd

7
from sqlalchemy import Table, Column, Integer, ForeignKey
Aurélien Campéas's avatar
Aurélien Campéas committed
8
from sqlalchemy.sql.expression import select, func, desc
9
from sqlalchemy.dialects.postgresql import BYTEA
10

11
from tshistory.schema import tsschema
12
13
from tshistory.util import (
    inject_in_index,
14
15
    mindate,
    maxdate,
16
    num2float,
17
    subset,
18
    SeriesServices,
19
20
    tzaware_serie
)
21
from tshistory.snapshot import Snapshot
22

23
L = logging.getLogger('tshistory.tsio')
24
TABLES = {}
25
26


27
class TimeSerie(SeriesServices):
    """Versioned time-series store backed by a postgres schema.

    Reads/writes go through a SQLAlchemy connection `cn`; each write is
    recorded as a changeset carrying a diff against the previous state.
    """
    # changeset id shared by all inserts inside a `newchangeset` block
    # (None outside such a block)
    _csid = None
    # postgres schema namespace all tables live under
    namespace = 'tsh'
    # tshistory.schema.tsschema instance, set in __init__
    schema = None
31
32
33

    def __init__(self, namespace='tsh'):
        """Bind this store to the given postgres schema namespace."""
        self.metadatacache = {}
        self.namespace = namespace
        schema = tsschema(namespace)
        schema.define()
        self.schema = schema
37
38
39

    # API : changeset, insert, get, delete
    @contextmanager
    def newchangeset(self, cn, author, _insertion_date=None):
        """A context manager to allow insertion of several series within the
        same changeset identifier

        This allows to group changes to several series, hence
        producing a macro-change.

        _insertion_date is *only* provided for migration purposes and
        not part of the API.
        """
        assert self._csid is None
        self._csid = self._newchangeset(cn, author, _insertion_date)
        self._author = author
        try:
            yield
        finally:
            # always clear the shared changeset state: without the
            # finally, an exception in the managed body left _csid set
            # and the next newchangeset() call failed its assertion
            del self._csid
            del self._author
56

57
    def insert(self, cn, newts, name, author=None, _insertion_date=None):
        """Create a new revision of a given time series

        newts: pandas.Series with date index

        name: str unique identifier of the serie

        author: str free-form author name (mandatory, unless provided
        to the newchangeset context manager).
        """
        assert self._csid or author, 'author is mandatory'
        if self._csid and author:
            # fix: format string used to read 'author r{} ...' (stray 'r')
            L.info('author {} will not be used when in a changeset'.format(author))
            author = None
        assert isinstance(newts, pd.Series)
        assert not newts.index.duplicated().any()

        newts = num2float(newts)

        if not len(newts):
            return

        # index must be datetime-like, or a MultiIndex
        assert ('<M8[ns]' == newts.index.dtype or
                'datetime' in str(newts.index.dtype) or
                isinstance(newts.index, pd.MultiIndex))

        newts.name = name
        table = self._get_ts_table(cn, name)

        if isinstance(newts.index, pd.MultiIndex):
            # we impose an order to survive rountrips
            newts = newts.reorder_levels(sorted(newts.index.names))

        if table is None:
            # unknown serie: first revision
            return self._create(cn, newts, name, author, _insertion_date)

        return self._update(cn, table, newts, name, author, _insertion_date)
95

96
    def get(self, cn, name, revision_date=None,
            from_value_date=None, to_value_date=None,
            _keep_nans=False):
        """Compute and return the serie of a given name

        revision_date: datetime filter to get previous versions of the
        serie
        """
        if self._get_ts_table(cn, name) is None:
            # unknown serie
            return

        filters = (
            [lambda cset, _: cset.c.insertion_date <= revision_date]
            if revision_date else []
        )
        serie = Snapshot(cn, self, name).build_upto(
            filters,
            from_value_date=from_value_date,
            to_value_date=to_value_date
        )
        if serie is None or _keep_nans:
            return serie

        # strip the nan markers (they encode point erasures)
        serie.name = name
        return serie[~serie.isnull()]
121

122
    def metadata(self, cn, tsname):
        """Return metadata dict of timeserie."""
        # EAFP cache lookup: hit is the common case
        try:
            return self.metadatacache[tsname]
        except KeyError:
            pass
        registry = self.schema.registry
        query = select(
            [registry.c.metadata]
        ).where(registry.c.name == tsname)
        meta = cn.execute(query).scalar()
        self.metadatacache[tsname] = meta
        return meta

134
135
    def get_group(self, cn, name, revision_date=None):
        """Return a dict mapping serie names to series, for every serie
        touched by the latest changeset of `name`.
        """
        csid = self._latest_csid_for(cn, name)
        candidates = (
            (seriename, self.get(cn, seriename, revision_date))
            for seriename in self._changeset_series(cn, csid)
        )
        return {
            seriename: serie
            for seriename, serie in candidates
            if serie is not None
        }

144
145
    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None,
                    from_value_date=None,
                    to_value_date=None,
                    diffmode=False):
        """Compute the successive revisions of a serie over an insertion
        date range.

        diffmode=False: each revision is the fully reconstructed serie
        (base snapshot patched with successive diffs).
        diffmode=True: each revision is the raw stored diff.
        Returns a pd.Series with the revision date injected into the
        index, or None when the serie (or the requested range) is empty.
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        # compute diffs above the snapshot
        cset = self.schema.changeset
        diffsql = select([cset.c.id, cset.c.insertion_date, table.c.diff]
        ).order_by(cset.c.id
        ).where(table.c.cset == cset.c.id)

        if from_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date >= from_insertion_date)
        if to_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date <= to_insertion_date)

        diffs = cn.execute(diffsql).fetchall()
        if not diffs:
            # it's fine to ask for an insertion date range
            # where nothing did happen, but you get nothing
            return

        if diffmode:
            snapshot = Snapshot(cn, self, name)
            series = []
            for csid, revdate, diff in diffs:
                if diff is None:  # we must fetch the initial snapshot
                    serie = subset(snapshot.first, from_value_date, to_value_date)
                else:
                    serie = subset(self._deserialize(diff, name), from_value_date, to_value_date)
                    serie = self._ensure_tz_consistency(cn, serie)
                # stamp each diff with its revision date
                inject_in_index(serie, revdate)
                series.append(serie)
            series = pd.concat(series)
            series.name = name
            return series

        # reconstruction mode: build the state at the first returned
        # changeset, then patch forward with each later diff
        csid, revdate, diff_ = diffs[0]
        snap = Snapshot(cn, self, name)
        snapshot = snap.build_upto([lambda cset, _: cset.c.id <= csid],
                                   from_value_date, to_value_date)

        series = [(revdate, subset(snapshot, from_value_date, to_value_date))]
        for csid_, revdate, diff in diffs[1:]:
            diff = subset(self._deserialize(diff, table.name),
                          from_value_date, to_value_date)
            diff = self._ensure_tz_consistency(cn, diff)

            # patch the previous reconstructed state
            serie = self.patch(series[-1][1], diff)
            series.append((revdate, serie))

        for revdate, serie in series:
            inject_in_index(serie, revdate)

        serie = pd.concat([serie for revdate_, serie in series])
        serie.name = name
        return serie
206

207
208
209
    def exists(self, cn, name):
        """Tell whether a serie of the given name is registered."""
        table = self._get_ts_table(cn, name)
        return table is not None

210
    def latest_insertion_date(self, cn, name):
        """Return the insertion date of the most recent revision of a serie."""
        table = self._get_ts_table(cn, name)
        changeset = self.schema.changeset
        query = select(
            [func.max(changeset.c.insertion_date)]
        ).where(table.c.cset == changeset.c.id)
        return cn.execute(query).scalar()
216

217
218
219
220
    def changeset_at(self, cn, seriename, revdate, mode='strict'):
        """Return the changeset id of a serie at a given revision date.

        mode: 'strict' (exact insertion date), 'before' (<= revdate)
        or 'after' (>= revdate).
        """
        assert mode in ('strict', 'before', 'after')
        cset = self.schema.changeset
        table = self._table_definition_for(seriename)
        if mode == 'strict':
            datefilter = cset.c.insertion_date == revdate
        elif mode == 'before':
            datefilter = cset.c.insertion_date <= revdate
        else:
            datefilter = cset.c.insertion_date >= revdate
        query = select(
            [table.c.cset]
        ).where(
            table.c.cset == cset.c.id
        ).where(datefilter)
        return cn.execute(query).scalar()

    def strip(self, cn, seriename, csid):
        logs = self.log(cn, fromrev=csid, names=(seriename,))
        assert logs

        # put stripping info in the metadata
        cset = self.schema.changeset
        cset_serie = self.schema.changeset_series
        for log in logs:
            # update changeset.metadata
            metadata = cn.execute(
                select([cset.c.metadata]).where(cset.c.id == log['rev'])
            ).scalar() or {}
            metadata['tshistory.info'] = 'got stripped from {}'.format(csid)
            sql = cset.update().where(cset.c.id == log['rev']
            ).values(metadata=metadata)
            cn.execute(sql)
            # delete changset_serie item
            sql = cset_serie.delete().where(
250
                cset_serie.c.cset == log['rev']
251
252
253
254
255
256
257
            ).where(
                cset_serie.c.serie == seriename
            )
            cn.execute(sql)

        # wipe the diffs
        table = self._table_definition_for(seriename)
258
        cn.execute(table.delete().where(table.c.cset >= csid))
259

260
    def info(self, cn):
        """Gather global statistics on the current tshistory repository
        """
        # namespace is internal configuration, not user input
        ns = self.namespace
        return {
            'series count': cn.execute(
                'select count(*) from {}.registry'.format(ns)).scalar(),
            'changeset count': cn.execute(
                'select max(id) from {}.changeset'.format(ns)).scalar(),
            'serie names': [
                name for name, in cn.execute(
                    'select distinct name from {}.registry order by name'.format(ns)
                ).fetchall()
            ]
        }

271
    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            stripped=False,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.

        limit: keep only the `limit` most recent changesets
        diff: also attach, per changeset, the diff of each touched serie
        names / authors: filter on serie names / changeset authors
        stripped: include changesets whose serie links were stripped
        fromrev/torev, fromdate/todate: changeset id / insertion date bounds
        """
        log = []
        cset, cset_series, reg = (
            self.schema.changeset,
            self.schema.changeset_series,
            self.schema.registry
        )

        # newest first so that `limit` keeps the most recent changesets;
        # the final result is re-sorted chronologically below
        sql = select([cset.c.id, cset.c.author, cset.c.insertion_date, cset.c.metadata]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)

        if names:
            sql = sql.where(reg.c.name.in_(names))

        if authors:
            sql = sql.where(cset.c.author.in_(authors))

        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)

        if torev:
            sql = sql.where(cset.c.id <= torev)

        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)

        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        if stripped:
            # outerjoin to show dead things
            sql = sql.select_from(cset.outerjoin(cset_series))
        else:
            sql = sql.where(cset.c.id == cset_series.c.cset
            ).where(cset_series.c.serie == reg.c.name)

        rset = cn.execute(sql)
        for csetid, author, revdate, meta in rset.fetchall():
            log.append({'rev': csetid, 'author': author,
                        'date': pd.Timestamp(revdate, tz='utc'),
                        'meta': meta or {},
                        'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in log:
                rev['diff'] = {name: self.diff_at(cn, rev['rev'], name)
                               for name in rev['names']}

        # chronological (ascending changeset id) order for the caller
        log.sort(key=lambda rev: rev['rev'])
        return log

331
332
    # /API
    # Helpers
333

334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
    # creation / update

    def _create(self, cn, newts, name, author, insertion_date=None):
        """First insertion of a serie: build its initial snapshot,
        create its diff table and register the revision.
        """
        # an initial insertion only stores real points
        newts = newts[~newts.isnull()]
        if not len(newts):
            return None
        snap = Snapshot(cn, self, name)
        csid = self._csid or self._newchangeset(cn, author, insertion_date)
        head = snap.create(newts)
        table = self._make_ts_table(cn, name, newts)
        cn.execute(table.insert().values(cset=csid, snapshot=head))
        self._finalize_insertion(cn, csid, name)
        L.info('first insertion of %s (size=%s) by %s',
               name, len(newts), author or self._author)
        return newts

    def _update(self, cn, table, newts, name, author, insertion_date=None):
        """Insert a new revision of an existing serie, stored as a diff
        against its current state.
        """
        self._validate(cn, newts, name)
        snap = Snapshot(cn, self, name)
        diff = self.diff(snap.last(mindate(newts), maxdate(newts)), newts)
        if not len(diff):
            # nothing new: no revision is recorded
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author or self._author, len(newts))
            return

        csid = self._csid or self._newchangeset(cn, author, insertion_date)
        head = snap.update(diff)
        cn.execute(
            table.insert().values(
                cset=csid,
                diff=self._serialize(diff),
                snapshot=head
            )
        )
        self._finalize_insertion(cn, csid, name)

        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author or self._author)
        return diff

378
379
    # ts serialisation

380
381
382
383
384
    def _ensure_tz_consistency(self, cn, ts):
        """Return timeserie with tz aware index or not depending on metadata
        tzaware.

        Deserialized series come back naive; when the registry metadata
        says the serie is tzaware, localize the index to UTC.
        """
        assert ts.name is not None
        metadata = self.metadata(cn, ts.name)
        if metadata and metadata.get('tzaware', False):
            if isinstance(ts.index, pd.MultiIndex):
                # NOTE(review): localizes *every* level — assumes all
                # levels are datetime-like; tz_localize would raise on a
                # non-datetime level. TODO confirm against callers.
                for i in range(len(ts.index.levels)):
                    ts.index = ts.index.set_levels(
                        ts.index.levels[i].tz_localize('UTC'),
                        level=i)
                return ts
            return ts.tz_localize('UTC')
        return ts

396
    # serie table handling
397

398
399
    def _ts_table_name(self, seriename):
        """Return the qualified table name of a serie:
        <namespace>.timeserie.<seriename>
        """
        return '.'.join((self.namespace, 'timeserie', seriename))
401

402
    def _table_definition_for(self, seriename):
        """Return (building and caching it on first use) the SQLAlchemy
        Table holding the diffs of a serie.

        Each row links a changeset (`cset`) to a serialized `diff` and
        the snapshot `head` reached after applying it.
        """
        tablename = self._ts_table_name(seriename)
        # module-level cache: Table objects are built once per process
        table = TABLES.get(tablename)
        if table is None:
            TABLES[tablename] = table = Table(
                seriename, self.schema.meta,
                Column('id', Integer, primary_key=True),
                Column('cset', Integer,
                       ForeignKey('{}.changeset.id'.format(self.namespace)),
                       index=True, nullable=False),
                # diff is None for the very first revision (see diff_at)
                Column('diff', BYTEA),
                Column('snapshot', Integer,
                       ForeignKey('{}.snapshot.{}.id'.format(
                           self.namespace,
                           seriename)),
                       index=True),
                schema='{}.timeserie'.format(self.namespace),
                extend_existing=True
            )
        return table
422

423
    def _make_ts_table(self, cn, name, ts):
        """Create the diff table of a new serie and register the serie
        (with its index/value type metadata) in the registry.
        """
        table = self._table_definition_for(name)
        table.create(cn)
        index = ts.index
        cn.execute(
            self.schema.registry.insert().values(
                name=name,
                table_name=self._ts_table_name(name),
                metadata={
                    'tzaware': tzaware_serie(ts),
                    'index_type': index.dtype.name,
                    # only the named levels of the index
                    'index_names': [iname for iname in index.names if iname],
                    'value_type': ts.dtypes.name
                },
            )
        )
        return table

442
    def _get_ts_table(self, cn, name):
        """Return the diff table of a serie, or None when the serie is
        not registered.
        """
        registry = self.schema.registry
        query = registry.select().where(
            registry.c.table_name == self._ts_table_name(name)
        )
        if cn.execute(query).scalar():
            return self._table_definition_for(name)
        return None
449

450
451
    # changeset handling

452
    def _newchangeset(self, cn, author, _insertion_date=None):
        """Insert a new changeset row and return its id."""
        if _insertion_date is not None:
            # externally provided dates (migrations) must be tz-aware
            assert _insertion_date.tzinfo is not None
        when = pd.Timestamp(_insertion_date or datetime.utcnow(), tz='UTC')
        insertsql = self.schema.changeset.insert().values(
            author=author,
            insertion_date=when
        )
        return cn.execute(insertsql).inserted_primary_key[0]
461

462
463
    def _latest_csid_for(self, cn, name):
        """Return the most recent changeset id of a serie."""
        table = self._get_ts_table(cn, name)
        return cn.execute(
            select([func.max(table.c.cset)])
        ).scalar()
466

467
    def _changeset_series(self, cn, csid):
        """List the names of all series touched by changeset `csid`."""
        relation = self.schema.changeset_series
        query = select(
            [relation.c.serie]
        ).where(relation.c.cset == csid)
        return [row[0] for row in cn.execute(query).fetchall()]
473
474
475

    # insertion handling

476
    def _validate(self, cn, ts, name):
        """Check that a new revision is type-compatible with the
        registered serie (value type, index type, index names).
        """
        if ts.isnull().all():
            # all-nan serie: this is an erasure, nothing to validate
            return
        meta = self.metadata(cn, name)
        newtype = ts.dtype
        if newtype != meta['value_type']:
            raise Exception(
                'Type error when inserting {}, new type is {}, type in base is {}'.format(
                    name, newtype, meta['value_type'])
            )
        if ts.index.dtype.name != meta['index_type']:
            raise Exception('Incompatible index types')
        inames = [iname for iname in ts.index.names if iname]
        if inames != meta['index_names']:
            raise Exception('Incompatible multi indexes: {} vs {}'.format(
                meta['index_names'], inames))
493

494
495
496
497
498
499
500
    def _finalize_insertion(self, cn, csid, name):
        """Record the (changeset, serie) association."""
        cn.execute(
            self.schema.changeset_series.insert().values(
                cset=csid,
                serie=name
            )
        )
501

502
    def diff_at(self, cn, csetid, name):
        """Return the diff a changeset introduced for a serie (the
        initial snapshot when the changeset is the first revision).
        """
        table = self._get_ts_table(cn, name)
        cset = self.schema.changeset

        def atrevision(query):
            # restrict to the serie rows belonging to changeset `csetid`
            return query.where(
                table.c.cset == cset.c.id
            ).where(cset.c.id == csetid)

        tsid = cn.execute(atrevision(select([table.c.id]))).scalar()
        if tsid == 1:
            # first row: its diff column is empty, serve the snapshot
            return Snapshot(cn, self, name).first

        rawdiff = cn.execute(atrevision(select([table.c.diff]))).scalar()
        return self._ensure_tz_consistency(
            cn, self._deserialize(rawdiff, name)
        )