# tsio.py
from contextlib import contextmanager
from datetime import datetime
import logging

import pandas as pd

from sqlalchemy import Table, Column, Integer, ForeignKey
from sqlalchemy.sql.expression import select, func, desc
from sqlalchemy.dialects.postgresql import BYTEA

from tshistory.schema import tsschema
from tshistory.util import (
    inject_in_index,
    mindate,
    maxdate,
    num2float,
    subset,
    SeriesServices,
    tzaware_serie
)
from tshistory.snapshot import Snapshot

# logger used by every TimeSerie method of this module
L = logging.getLogger('tshistory.tsio')

# process-wide cache of the sqlalchemy Table objects describing the
# per-series tables, keyed by '<namespace>.timeserie.<seriename>'
TABLES = dict()


class TimeSerie(SeriesServices):
    """Versioned storage of pandas time series in a postgres database.

    Every insertion creates a changeset (author, insertion date) and, in
    the per-series table, a row holding either nothing (first insertion,
    the initial state lives in the snapshot) or the serialized diff
    against the previous state.  The `cn` argument of every method is
    something sqlalchemy queries can be `execute`d on (a connection;
    presumably also an engine -- TODO confirm with callers).
    """
    namespace = 'tsh'
    schema = None

    def __init__(self, namespace='tsh'):
        """Bind to the database namespace `namespace` and define its tables."""
        self.namespace = namespace
        self.schema = tsschema(namespace)
        self.schema.define()
        # series name -> registry metadata dict (only found series are cached)
        self.metadatacache = {}

    def insert(self, cn, newts, name, author, _insertion_date=None):
        """Create a new revision of a given time series

        newts: pandas.Series with date index
        name: str unique identifier of the serie
        author: str free-form author name
        _insertion_date: optional tz-aware datetime used as the changeset
        insertion date (defaults to the current UTC time)

        Returns None for an empty series, otherwise the value returned by
        `_create` (first insertion) or `_update` (subsequent ones).
        """
        assert isinstance(newts, pd.Series)
        assert not newts.index.duplicated().any()

        newts = num2float(newts)

        if not len(newts):
            return

        assert ('<M8[ns]' == newts.index.dtype or
                'datetime' in str(newts.index.dtype) or
                isinstance(newts.index, pd.MultiIndex))

        newts.name = name
        table = self._get_ts_table(cn, name)

        if isinstance(newts.index, pd.MultiIndex):
            # we impose an order to survive roundtrips
            newts = newts.reorder_levels(sorted(newts.index.names))

        if table is None:
            return self._create(cn, newts, name, author, _insertion_date)

        return self._update(cn, table, newts, name, author, _insertion_date)

    def get(self, cn, name, revision_date=None,
            from_value_date=None, to_value_date=None,
            _keep_nans=False):
        """Compute and return the serie of a given name

        revision_date: datetime filter to get previous versions of the
        serie (only changesets inserted at or before it are used)
        from_value_date, to_value_date: optional bounds on the index
        _keep_nans: if true, NaN points are kept instead of dropped

        Returns None when the series does not exist.
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        qfilter = []
        if revision_date:
            qfilter.append(lambda cset, _: cset.c.insertion_date <= revision_date)

        snap = Snapshot(cn, self, name)
        current = snap.build_upto(qfilter,
                                  from_value_date=from_value_date,
                                  to_value_date=to_value_date)

        if current is not None:
            # name the result whether or not the NaNs are kept
            current.name = name
            if not _keep_nans:
                current = current[~current.isnull()]
        return current

    def metadata(self, cn, tsname):
        """Return metadata dict of timeserie.

        Returns None when `tsname` is not registered.  Only found metadata
        is cached, so asking about a series before it is created does not
        hide its metadata afterwards.
        """
        if tsname in self.metadatacache:
            return self.metadatacache[tsname]
        reg = self.schema.registry
        sql = select([reg.c.metadata]).where(
            reg.c.name == tsname
        )
        meta = cn.execute(sql).scalar()
        if meta is not None:
            self.metadatacache[tsname] = meta
        return meta

    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None,
                    from_value_date=None,
                    to_value_date=None,
                    diffmode=False):
        """Return the successive versions of a series, concatenated.

        Each version is passed to `inject_in_index` with its insertion
        date (presumably adding it to the index -- see tshistory.util)
        before all versions are concatenated into one named series.

        from_insertion_date, to_insertion_date: inclusive bounds on the
        changeset insertion dates
        from_value_date, to_value_date: bounds on the series index
        diffmode: if true, return each changeset's diff (the initial
        snapshot for the first insertion) instead of the full states

        Returns None if the series does not exist or if no changeset falls
        in the insertion date range.
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        # compute diffs above the snapshot
        cset = self.schema.changeset
        diffsql = select([cset.c.id, cset.c.insertion_date, table.c.diff]
        ).order_by(cset.c.id
        ).where(table.c.cset == cset.c.id)

        if from_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date >= from_insertion_date)
        if to_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date <= to_insertion_date)

        diffs = cn.execute(diffsql).fetchall()
        if not diffs:
            # it's fine to ask for an insertion date range
            # where nothing did happen, but you get nothing
            return

        if diffmode:
            snap = Snapshot(cn, self, name)
            series = []
            for _csid, revdate, diff in diffs:
                if diff is None:  # we must fetch the initial snapshot
                    serie = subset(snap.first, from_value_date, to_value_date)
                else:
                    serie = subset(self._deserialize(diff, name),
                                   from_value_date, to_value_date)
                    serie = self._ensure_tz_consistency(cn, serie)
                inject_in_index(serie, revdate)
                series.append(serie)
            series = pd.concat(series)
            series.name = name
            return series

        # full states: rebuild the state at the first changeset in range,
        # then patch it forward with each following diff
        csid, revdate, _diff = diffs[0]
        snap = Snapshot(cn, self, name)
        snapshot = snap.build_upto([lambda cset, _: cset.c.id <= csid],
                                   from_value_date, to_value_date)

        series = [(revdate, subset(snapshot, from_value_date, to_value_date))]
        for _csid, revdate, diff in diffs[1:]:
            diff = subset(self._deserialize(diff, name),
                          from_value_date, to_value_date)
            diff = self._ensure_tz_consistency(cn, diff)

            serie = self.patch(series[-1][1], diff)
            series.append((revdate, serie))

        # inject only once every state is computed, so that each patch
        # above saw the previous state with its original index
        for revdate, serie in series:
            inject_in_index(serie, revdate)

        serie = pd.concat([serie for _revdate, serie in series])
        serie.name = name
        return serie

    def exists(self, cn, name):
        """Tell whether a series named `name` is registered."""
        return self._get_ts_table(cn, name) is not None

    def latest_insertion_date(self, cn, name):
        """Return the latest changeset insertion date of series `name`.

        Returns None when the series does not exist.
        """
        tstable = self._get_ts_table(cn, name)
        if tstable is None:
            return None
        cset = self.schema.changeset
        sql = select([func.max(cset.c.insertion_date)]
        ).where(tstable.c.cset == cset.c.id)
        return cn.execute(sql).scalar()

    def changeset_at(self, cn, seriename, revdate, mode='strict'):
        """Return the id of a changeset of `seriename` relative to `revdate`.

        mode:
        * 'strict': the changeset inserted exactly at `revdate`
        * 'before': the latest changeset inserted at or before `revdate`
        * 'after': the earliest changeset inserted at or after `revdate`

        Returns None when no changeset matches.
        """
        assert mode in ('strict', 'before', 'after')
        cset = self.schema.changeset
        table = self._table_definition_for(seriename)
        sql = select([table.c.cset]).where(
            table.c.cset == cset.c.id
        )
        if mode == 'strict':
            sql = sql.where(cset.c.insertion_date == revdate)
        elif mode == 'before':
            # without an explicit ordering the database may return any
            # matching row; we want the closest one to revdate
            sql = sql.where(cset.c.insertion_date <= revdate
            ).order_by(desc(cset.c.insertion_date), desc(cset.c.id))
        else:
            sql = sql.where(cset.c.insertion_date >= revdate
            ).order_by(cset.c.insertion_date, cset.c.id)
        return cn.execute(sql.limit(1)).scalar()

    def strip(self, cn, seriename, csid):
        """Remove series `seriename` from changeset `csid` onwards.

        The affected changesets are kept, but their metadata is tagged
        with a 'tshistory.info' entry and they are detached from the
        series; the series table rows from `csid` onwards are deleted.
        """
        logs = self.log(cn, fromrev=csid, names=(seriename,))
        assert logs

        # put stripping info in the metadata
        cset = self.schema.changeset
        cset_serie = self.schema.changeset_series
        for entry in logs:
            rev = entry['rev']
            # update changeset.metadata
            metadata = cn.execute(
                select([cset.c.metadata]).where(cset.c.id == rev)
            ).scalar() or {}
            metadata['tshistory.info'] = 'got stripped from {}'.format(csid)
            sql = cset.update().where(cset.c.id == rev
            ).values(metadata=metadata)
            cn.execute(sql)
            # delete changeset_serie item
            sql = cset_serie.delete().where(
                cset_serie.c.cset == rev
            ).where(
                cset_serie.c.serie == seriename
            )
            cn.execute(sql)

        # wipe the diffs
        table = self._table_definition_for(seriename)
        cn.execute(table.delete().where(table.c.cset >= csid))

    def info(self, cn):
        """Gather global statistics on the current tshistory repository

        Returns a dict with the number of registered series, the highest
        changeset id (reported as 'changeset count') and the sorted list
        of series names.
        """
        sql = 'select count(*) from {}.registry'.format(self.namespace)
        stats = {'series count': cn.execute(sql).scalar()}
        sql = 'select max(id) from {}.changeset'.format(self.namespace)
        stats['changeset count'] = cn.execute(sql).scalar()
        sql = 'select distinct name from {}.registry order by name'.format(self.namespace)
        stats['serie names'] = [row for row, in cn.execute(sql).fetchall()]
        return stats

    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            stripped=False,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.

        Returns a list of dicts with keys 'rev' (changeset id), 'author',
        'date' (UTC timestamp), 'meta' (changeset metadata, {} if none) and
        'names' (series of the changeset); with `diff=True`, also 'diff'
        mapping each of those series names to its diff at that changeset.

        limit: only the `limit` highest changeset ids (0 means no limit)
        names, authors: restrict to these series names / authors
        stripped: outer join, so changesets detached from every series
        (e.g. by `strip`) are shown too
        fromrev, torev, fromdate, todate: inclusive bounds on changeset
        ids / insertion dates
        """
        entries = []
        cset, cset_series, reg = (
            self.schema.changeset,
            self.schema.changeset_series,
            self.schema.registry
        )

        sql = select([cset.c.id, cset.c.author, cset.c.insertion_date, cset.c.metadata]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)

        if names:
            sql = sql.where(reg.c.name.in_(names))

        if authors:
            sql = sql.where(cset.c.author.in_(authors))

        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)

        if torev:
            sql = sql.where(cset.c.id <= torev)

        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)

        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        if stripped:
            # outerjoin to show dead things
            sql = sql.select_from(cset.outerjoin(cset_series))
        else:
            sql = sql.where(cset.c.id == cset_series.c.cset
            ).where(cset_series.c.serie == reg.c.name)

        rset = cn.execute(sql)
        for csetid, author, revdate, meta in rset.fetchall():
            entries.append({'rev': csetid, 'author': author,
                            'date': self._utc_timestamp(revdate),
                            'meta': meta or {},
                            'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in entries:
                rev['diff'] = {name: self.diff_at(cn, rev['rev'], name)
                               for name in rev['names']}

        entries.sort(key=lambda rev: rev['rev'])
        return entries

    # /API
    # Helpers

    @staticmethod
    def _utc_timestamp(dt):
        """Return `dt` as a UTC pandas Timestamp.

        Naive values are taken to be UTC and localized; aware values are
        converted.  (Passing an aware value together with `tz=` to
        pd.Timestamp is rejected by recent pandas versions.)
        """
        stamp = pd.Timestamp(dt)
        if stamp.tzinfo is None:
            return stamp.tz_localize('UTC')
        return stamp.tz_convert('UTC')

    # creation / update

    def _create(self, cn, newts, name, author, insertion_date=None):
        """First insertion of series `name`.

        NaN points are dropped; if nothing is left, nothing is written and
        None is returned.  Otherwise the initial snapshot, the series table,
        its registry entry and a changeset are created, and the stored
        series is returned.
        """
        # initial insertion
        newts = newts[~newts.isnull()]
        if len(newts) == 0:
            return None
        snapshot = Snapshot(cn, self, name)
        csid = self._newchangeset(cn, author, insertion_date)
        head = snapshot.create(newts)
        # no 'diff' value: a NULL diff marks the first insertion
        # (get_history and diff_at rely on it)
        value = {
            'cset': csid,
            'snapshot': head
        }
        table = self._make_ts_table(cn, name, newts)
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)
        L.info('first insertion of %s (size=%s) by %s',
               name, len(newts), author)
        return newts

    def _update(self, cn, table, newts, name, author, insertion_date=None):
        """Store the difference between `newts` and the current state.

        Raises (via `_validate`) if `newts` does not match the stored value
        type, index type or index names.  Returns None and writes nothing
        when there is no difference, otherwise returns the stored diff.
        """
        self._validate(cn, newts, name)
        snapshot = Snapshot(cn, self, name)
        # presumably the current state restricted to newts' date range
        diff = self.diff(snapshot.last(mindate(newts), maxdate(newts)), newts)
        if not len(diff):
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author, len(newts))
            return

        csid = self._newchangeset(cn, author, insertion_date)
        head = snapshot.update(diff)
        value = {
            'cset': csid,
            'diff': self._serialize(diff),
            'snapshot': head
        }
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)

        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author)
        return diff

    # ts serialisation

    def _ensure_tz_consistency(self, cn, ts):
        """Return timeserie with tz aware index or not depending on metadata
        tzaware.

        When the metadata of series `ts.name` has a true 'tzaware' flag,
        the index (presumably naive at this point) is localized to UTC; for
        a MultiIndex every level is localized and `ts` is modified in place.
        Otherwise `ts` is returned untouched.
        """
        assert ts.name is not None
        metadata = self.metadata(cn, ts.name)
        if metadata and metadata.get('tzaware', False):
            if isinstance(ts.index, pd.MultiIndex):
                for i in range(len(ts.index.levels)):
                    ts.index = ts.index.set_levels(
                        ts.index.levels[i].tz_localize('UTC'),
                        level=i)
                return ts
            return ts.tz_localize('UTC')
        return ts

    # serie table handling

    def _ts_table_name(self, seriename):
        """Return the qualified name '<namespace>.timeserie.<seriename>'."""
        # namespace.seriename
        return '{}.timeserie.{}'.format(self.namespace, seriename)

    def _table_definition_for(self, seriename):
        """Return the (cached) sqlalchemy Table of series `seriename`.

        Only the definition is built here; nothing is created in the db.
        """
        tablename = self._ts_table_name(seriename)
        table = TABLES.get(tablename)
        if table is None:
            TABLES[tablename] = table = Table(
                seriename, self.schema.meta,
                Column('id', Integer, primary_key=True),
                Column('cset', Integer,
                       ForeignKey('{}.changeset.id'.format(self.namespace)),
                       index=True, nullable=False),
                Column('diff', BYTEA),
                Column('snapshot', Integer,
                       ForeignKey('{}.snapshot.{}.id'.format(
                           self.namespace,
                           seriename)),
                       index=True),
                schema='{}.timeserie'.format(self.namespace),
                extend_existing=True
            )
        return table

    def _make_ts_table(self, cn, name, ts):
        """Create the table of series `name` and register it.

        The registry metadata records whether the index is tz aware, plus
        the index type, index names and value type that `_validate` checks
        on later insertions.
        """
        tablename = self._ts_table_name(name)
        table = self._table_definition_for(name)
        table.create(cn)
        index = ts.index
        inames = [iname for iname in index.names if iname]
        sql = self.schema.registry.insert().values(
            name=name,
            table_name=tablename,
            metadata={
                'tzaware': tzaware_serie(ts),
                'index_type': index.dtype.name,
                'index_names': inames,
                'value_type': ts.dtypes.name
            },
        )
        cn.execute(sql)
        return table

    def _get_ts_table(self, cn, name):
        """Return the Table of series `name`, or None if it is not registered."""
        reg = self.schema.registry
        tablename = self._ts_table_name(name)
        sql = reg.select().where(reg.c.table_name == tablename)
        # test for the existence of a row, not the truthiness of a column
        if cn.execute(sql).first() is None:
            return None
        return self._table_definition_for(name)

    # changeset handling

    def _newchangeset(self, cn, author, _insertion_date=None):
        """Insert a new changeset and return its id.

        _insertion_date, if given, must be tz aware and is converted to
        UTC; it defaults to the current UTC time.
        """
        table = self.schema.changeset
        if _insertion_date is not None:
            assert _insertion_date.tzinfo is not None
            idate = self._utc_timestamp(_insertion_date)
        else:
            idate = self._utc_timestamp(datetime.utcnow())
        sql = table.insert().values(
            author=author,
            insertion_date=idate)
        return cn.execute(sql).inserted_primary_key[0]

    def _changeset_series(self, cn, csid):
        """Return the names of the series attached to changeset `csid`."""
        cset_serie = self.schema.changeset_series
        sql = select([cset_serie.c.serie]
        ).where(cset_serie.c.cset == csid)

        return [seriename for seriename, in cn.execute(sql).fetchall()]

    # insertion handling

    def _validate(self, cn, ts, name):
        """Check `ts` against the registry metadata of series `name`.

        An all-NaN series is accepted as is (erasure).  Otherwise raises
        Exception when the value type, the index type or the (non-empty)
        index names differ from the registered ones.
        """
        if ts.isnull().all():
            # ts erasure
            return
        tstype = ts.dtype
        meta = self.metadata(cn, name)
        if tstype != meta['value_type']:
            m = 'Type error when inserting {}, new type is {}, type in base is {}'.format(
                name, tstype, meta['value_type'])
            raise Exception(m)
        if ts.index.dtype.name != meta['index_type']:
            raise Exception('Incompatible index types')
        inames = [iname for iname in ts.index.names if iname]
        if inames != meta['index_names']:
            raise Exception('Incompatible multi indexes: {} vs {}'.format(
                meta['index_names'], inames)
            )

    def _finalize_insertion(self, cn, csid, name):
        """Attach series `name` to changeset `csid`."""
        table = self.schema.changeset_series
        sql = table.insert().values(
            cset=csid,
            serie=name
        )
        cn.execute(sql)

    def diff_at(self, cn, csetid, name):
        """Return the diff of series `name` at changeset `csetid`.

        For the first insertion (stored with a NULL diff) this is the
        initial snapshot.  Returns None when the series does not exist or
        has no row for that changeset.
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return None

        row = cn.execute(
            select([table.c.diff]).where(table.c.cset == csetid)
        ).first()
        if row is None:
            return None

        diff = row[0]
        if diff is None:
            # first insertion: its content lives in the initial snapshot
            return Snapshot(cn, self, name).first

        ts = self._deserialize(diff, name)
        return self._ensure_tz_consistency(cn, ts)