from datetime import datetime
from contextlib import contextmanager
import logging

import pandas as pd

from sqlalchemy import Table, Column, Integer, ForeignKey
from sqlalchemy.sql.expression import select, func, desc
from sqlalchemy.dialects.postgresql import BYTEA

from tshistory.schema import tsschema
from tshistory.util import (
    inject_in_index,
    mindate,
    maxdate,
    num2float,
    subset,
    SeriesServices,
    tzaware_serie
)
from tshistory.snapshot import Snapshot

L = logging.getLogger('tshistory.tsio')
TABLES = {}


class TimeSerie(SeriesServices):
    namespace = 'tsh'
    schema = None

    def __init__(self, namespace='tsh'):
        self.namespace = namespace
        self.schema = tsschema(namespace)
        self.schema.define()
        self.metadatacache = {}

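    # Illustrative usage sketch, not part of the original module: the `engine`
    # object and the 'my_serie' name are assumptions. With a connection to a
    # tshistory-enabled database, a revision could be inserted like this:
    #
    #   tsh = TimeSerie()
    #   serie = pd.Series([1.0, 2.0, 3.0],
    #                     index=pd.date_range('2018-1-1', periods=3, freq='D'))
    #   with engine.begin() as cn:
    #       tsh.insert(cn, serie, 'my_serie', 'babar@example.com')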
    def insert(self, cn, newts, name, author, _insertion_date=None):
        """Create a new revision of a given time series

        newts: pandas.Series with date index
        name: str unique identifier of the serie
        author: str free-form author name
        """
        assert isinstance(newts, pd.Series)
        assert not newts.index.duplicated().any()

        newts = num2float(newts)

        if not len(newts):
            return

        assert ('<M8[ns]' == newts.index.dtype or
                'datetime' in str(newts.index.dtype) and not
                isinstance(newts.index, pd.MultiIndex))

        newts.name = name
        table = self._get_ts_table(cn, name)

        if table is None:
            return self._create(cn, newts, name, author, _insertion_date)

        return self._update(cn, table, newts, name, author, _insertion_date)

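    # Illustrative read sketch, not part of the original module, reusing the
    # hypothetical `tsh`, `cn` and 'my_serie' from the insertion example above:
    #
    #   current = tsh.get(cn, 'my_serie')
    #   asof = tsh.get(cn, 'my_serie',
    #                  revision_date=pd.Timestamp('2018-1-1', tz='UTC'))
    #   # both calls return a pandas Series, or None if the serie is unknown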
    def get(self, cn, name, revision_date=None,
            from_value_date=None, to_value_date=None,
            _keep_nans=False):
        """Compute and return the serie of a given name

        revision_date: datetime filter to get previous versions of the
        serie

        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        qfilter = []
        if revision_date:
            qfilter.append(lambda cset, _: cset.c.insertion_date <= revision_date)
        snap = Snapshot(cn, self, name)
        _, current = snap.find(qfilter,
                               from_value_date=from_value_date,
                               to_value_date=to_value_date)

        if current is not None and not _keep_nans:
            current.name = name
            current = current[~current.isnull()]
        return current

    def metadata(self, cn, tsname):
        """Return metadata dict of timeserie."""
        if tsname in self.metadatacache:
            return self.metadatacache[tsname]
        reg = self.schema.registry
        sql = select([reg.c.metadata]).where(
            reg.c.name == tsname
        )
        meta = cn.execute(sql).scalar()
        self.metadatacache[tsname] = meta
        return meta

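    # Illustrative sketch, not part of the original module: get_history walks
    # the changesets touching a serie and returns a single concatenated series
    # tagged with revision dates (or None when nothing matches), e.g.
    #
    #   hist = tsh.get_history(cn, 'my_serie',
    #                          from_insertion_date=pd.Timestamp('2018-1-1', tz='UTC'))
    #   diffs = tsh.get_history(cn, 'my_serie', diffmode=True)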
    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None,
                    from_value_date=None,
                    to_value_date=None,
                    deltabefore=None,
                    deltaafter=None,
                    diffmode=False):
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        if deltabefore is not None or deltaafter is not None:
            assert diffmode is False
            assert from_value_date is None
            assert to_value_date is None

        cset = self.schema.changeset

        if diffmode:
            # compute diffs above the snapshot
            diffsql = select([cset.c.id, cset.c.insertion_date, table.c.diff]
            ).order_by(cset.c.id
            ).where(table.c.cset == cset.c.id)

            if from_insertion_date:
                diffsql = diffsql.where(cset.c.insertion_date >= from_insertion_date)
            if to_insertion_date:
                diffsql = diffsql.where(cset.c.insertion_date <= to_insertion_date)

            diffs = cn.execute(diffsql).fetchall()
            if not diffs:
                # it's fine to ask for an insertion date range
                # where nothing happened, but you get nothing back
                return

            snapshot = Snapshot(cn, self, name)
            series = []
            for csid, revdate, diff in diffs:
                if diff is None:  # we must fetch the initial snapshot
                    serie = subset(snapshot.first, from_value_date, to_value_date)
                else:
                    serie = subset(self._deserialize(diff, name), from_value_date, to_value_date)
                    serie = self._ensure_tz_consistency(cn, serie)
                inject_in_index(serie, revdate)
                series.append(serie)
            series = pd.concat(series)
            series.name = name
            return series

        revsql = select(
            [cset.c.id, cset.c.insertion_date]
        ).order_by(
            cset.c.id
        ).where(
            table.c.cset == cset.c.id
        )

        if from_insertion_date:
            revsql = revsql.where(cset.c.insertion_date >= from_insertion_date)
        if to_insertion_date:
            revsql = revsql.where(cset.c.insertion_date <= to_insertion_date)

        revs = cn.execute(revsql).fetchall()
        if not revs:
            return

        snapshot = Snapshot(cn, self, name)
        series = []
        for csid, idate in revs:
            if deltabefore or deltaafter:
                from_value_date = idate
                to_value_date = idate
                if deltabefore:
                    from_value_date = idate - deltabefore
                if deltaafter:
                    to_value_date = idate + deltaafter
            series.append((
                idate,
                snapshot.find([lambda cset, _: cset.c.id == csid],
                               from_value_date=from_value_date,
                               to_value_date=to_value_date)[1]
            ))

        for revdate, serie in series:
            inject_in_index(serie, revdate)

        serie = pd.concat([serie for revdate_, serie in series])
        serie.name = name
        return serie

    def exists(self, cn, name):
        return self._get_ts_table(cn, name) is not None

    def latest_insertion_date(self, cn, name):
        cset = self.schema.changeset
        tstable = self._get_ts_table(cn, name)
        sql = select([func.max(cset.c.insertion_date)]
        ).where(tstable.c.cset == cset.c.id)
        return cn.execute(sql).scalar()

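    # Illustrative sketch, not part of the original module: changeset_at maps an
    # insertion date to a changeset id, e.g.
    #
    #   csid = tsh.changeset_at(cn, 'my_serie',
    #                           pd.Timestamp('2018-1-2', tz='UTC'), mode='before')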
    def changeset_at(self, cn, seriename, revdate, mode='strict'):
        assert mode in ('strict', 'before', 'after')
        cset = self.schema.changeset
        table = self._table_definition_for(seriename)
        sql = select([table.c.cset]).where(
            table.c.cset == cset.c.id
        )
        if mode == 'strict':
            sql = sql.where(cset.c.insertion_date == revdate)
        elif mode == 'before':
            sql = sql.where(cset.c.insertion_date <= revdate)
        else:
            sql = sql.where(cset.c.insertion_date >= revdate)
        return cn.execute(sql).scalar()

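    # Illustrative sketch, not part of the original module: strip discards every
    # revision of a serie from a given changeset onward and leaves a note in the
    # changeset metadata; a hypothetical `cutoff_date` timestamp is assumed:
    #
    #   csid = tsh.changeset_at(cn, 'my_serie', cutoff_date, mode='after')
    #   tsh.strip(cn, 'my_serie', csid)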
    def strip(self, cn, seriename, csid):
        logs = self.log(cn, fromrev=csid, names=(seriename,))
        assert logs

        # put stripping info in the metadata
        cset = self.schema.changeset
        cset_serie = self.schema.changeset_series
        for log in logs:
            # update changeset.metadata
            metadata = cn.execute(
                select([cset.c.metadata]).where(cset.c.id == log['rev'])
            ).scalar() or {}
            metadata['tshistory.info'] = 'got stripped from {}'.format(csid)
            sql = cset.update().where(cset.c.id == log['rev']
            ).values(metadata=metadata)
            cn.execute(sql)
            # delete changeset_series item
            sql = cset_serie.delete().where(
                cset_serie.c.cset == log['rev']
            ).where(
                cset_serie.c.serie == seriename
            )
            cn.execute(sql)

        # wipe the diffs
        table = self._table_definition_for(seriename)
        cn.execute(table.delete().where(table.c.cset >= csid))

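    # Illustrative sketch, not part of the original module: info returns plain
    # repository-wide statistics, e.g.
    #
    #   stats = tsh.info(cn)
    #   # {'series count': 1, 'changeset count': 1, 'serie names': ['my_serie']}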
    def info(self, cn):
        """Gather global statistics on the current tshistory repository
        """
        sql = 'select count(*) from {}.registry'.format(self.namespace)
        stats = {'series count': cn.execute(sql).scalar()}
        sql = 'select max(id) from {}.changeset'.format(self.namespace)
        stats['changeset count'] = cn.execute(sql).scalar()
        sql = 'select distinct name from {}.registry order by name'.format(self.namespace)
        stats['serie names'] = [row for row, in cn.execute(sql).fetchall()]
        return stats

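    # Illustrative sketch, not part of the original module: log yields one dict
    # per changeset, oldest first, e.g.
    #
    #   for rev in tsh.log(cn, names=('my_serie',)):
    #       print(rev['rev'], rev['author'], rev['date'], rev['names'])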
    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            stripped=False,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.
        """
        log = []
        cset, cset_series, reg = (
            self.schema.changeset,
            self.schema.changeset_series,
            self.schema.registry
        )

        sql = select([cset.c.id, cset.c.author, cset.c.insertion_date, cset.c.metadata]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)

        if names:
            sql = sql.where(reg.c.name.in_(names))

        if authors:
            sql = sql.where(cset.c.author.in_(authors))

        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)

        if torev:
            sql = sql.where(cset.c.id <= torev)

        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)

        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        if stripped:
            # outerjoin to show dead things
            sql = sql.select_from(cset.outerjoin(cset_series))
        else:
            sql = sql.where(cset.c.id == cset_series.c.cset
            ).where(cset_series.c.serie == reg.c.name)

        rset = cn.execute(sql)
        for csetid, author, revdate, meta in rset.fetchall():
            log.append({'rev': csetid, 'author': author,
                        'date': pd.Timestamp(revdate, tz='utc'),
                        'meta': meta or {},
                        'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in log:
                rev['diff'] = {name: self.diff_at(cn, rev['rev'], name)
                               for name in rev['names']}

        log.sort(key=lambda rev: rev['rev'])
        return log

    # /API
    # Helpers

    # creation / update

    def _create(self, cn, newts, name, author, insertion_date=None):
        # initial insertion
        if len(newts) == 0:
            return None
        snapshot = Snapshot(cn, self, name)
        csid = self._newchangeset(cn, author, insertion_date)
        head = snapshot.create(newts)
        value = {
            'cset': csid,
            'snapshot': head
        }
        table = self._make_ts_table(cn, name, newts)
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)
        L.info('first insertion of %s (size=%s) by %s',
               name, len(newts), author or self._author)
        return newts

    def _update(self, cn, table, newts, name, author, insertion_date=None):
        self._validate(cn, newts, name)
        snapshot = Snapshot(cn, self, name)
        diff = self.diff(snapshot.last(mindate(newts), maxdate(newts)), newts)
        if not len(diff):
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author or self._author, len(newts))
            return

        csid = self._newchangeset(cn, author, insertion_date)
        head = snapshot.update(diff)
        value = {
            'cset': csid,
            'diff': self._serialize(diff),
            'snapshot': head
        }
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)

        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author or self._author)
        return diff

    # ts serialisation

    def _ensure_tz_consistency(self, cn, ts):
        """Return timeserie with tz aware index or not depending on metadata
        tzaware.
        """
        assert ts.name is not None
        metadata = self.metadata(cn, ts.name)
        if metadata and metadata.get('tzaware', False):
            return ts.tz_localize('UTC')
        return ts

    # serie table handling

    def _ts_table_name(self, seriename):
        # namespace.seriename
        return '{}.timeserie.{}'.format(self.namespace, seriename)

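    # the module-level TABLES dict caches the sqlalchemy Table objects so that a
    # given serie table is only defined once per process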
    def _table_definition_for(self, seriename):
        tablename = self._ts_table_name(seriename)
        table = TABLES.get(tablename)
        if table is None:
            TABLES[tablename] = table = Table(
                seriename, self.schema.meta,
                Column('id', Integer, primary_key=True),
                Column('cset', Integer,
                       ForeignKey('{}.changeset.id'.format(self.namespace)),
                       index=True, nullable=False),
                Column('diff', BYTEA),
                Column('snapshot', Integer,
                       ForeignKey('{}.snapshot.{}.id'.format(
                           self.namespace,
                           seriename)),
                       index=True),
                schema='{}.timeserie'.format(self.namespace),
                extend_existing=True
            )
        return table

    def _make_ts_table(self, cn, name, ts):
        tablename = self._ts_table_name(name)
        table = self._table_definition_for(name)
        table.create(cn)
        index = ts.index
        inames = [name for name in index.names if name]
        sql = self.schema.registry.insert().values(
            name=name,
            table_name=tablename,
            metadata={
                'tzaware': tzaware_serie(ts),
                'index_type': index.dtype.name,
                'index_names': inames,
                'value_type': ts.dtypes.name
            },
        )
        cn.execute(sql)
        return table

    def _get_ts_table(self, cn, name):
        reg = self.schema.registry
        tablename = self._ts_table_name(name)
        sql = reg.select().where(reg.c.table_name == tablename)
        tid = cn.execute(sql).scalar()
        if tid:
            return self._table_definition_for(name)

    # changeset handling

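    # note: insertion dates are normalised to UTC; a caller-provided
    # _insertion_date must already be timezone-aware (see the assert below)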
    def _newchangeset(self, cn, author, _insertion_date=None):
        table = self.schema.changeset
        if _insertion_date is not None:
            assert _insertion_date.tzinfo is not None
        idate = pd.Timestamp(_insertion_date or datetime.utcnow(), tz='UTC')
        sql = table.insert().values(
            author=author,
            insertion_date=idate)
        return cn.execute(sql).inserted_primary_key[0]

    def _changeset_series(self, cn, csid):
        cset_serie = self.schema.changeset_series
        sql = select([cset_serie.c.serie]
        ).where(cset_serie.c.cset == csid)

        return [seriename for seriename, in cn.execute(sql).fetchall()]

    # insertion handling

    def _validate(self, cn, ts, name):
        if ts.isnull().all():
            # ts erasure
            return
        tstype = ts.dtype
        meta = self.metadata(cn, name)
        if tstype != meta['value_type']:
            m = 'Type error when inserting {}, new type is {}, type in base is {}'.format(
                name, tstype, meta['value_type'])
            raise Exception(m)
        if ts.index.dtype.name != meta['index_type']:
            raise Exception('Incompatible index types')

    def _finalize_insertion(self, cn, csid, name):
        table = self.schema.changeset_series
        sql = table.insert().values(
            cset=csid,
            serie=name
        )
        cn.execute(sql)

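    # Illustrative sketch, not part of the original module: diff_at returns what
    # a single changeset brought to a serie (the initial snapshot for the first
    # revision, the stored diff afterwards), e.g.
    #
    #   delta = tsh.diff_at(cn, csid, 'my_serie')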
    def diff_at(self, cn, csetid, name):
        table = self._get_ts_table(cn, name)
        cset = self.schema.changeset

        def filtercset(sql):
            return sql.where(table.c.cset == cset.c.id
            ).where(cset.c.id == csetid)

        sql = filtercset(select([table.c.id]))
        tsid = cn.execute(sql).scalar()

        if tsid == 1:
            return Snapshot(cn, self, name).first

        sql = filtercset(select([table.c.diff]))
        ts = self._deserialize(cn.execute(sql).scalar(), name)
        return self._ensure_tz_consistency(cn, ts)