from datetime import datetime
from contextlib import contextmanager
import logging

import pickle
import zlib

import pandas as pd
import numpy as np

from sqlalchemy import Table, Column, Integer, ForeignKey
from sqlalchemy.sql.expression import select, func, desc
from sqlalchemy.dialects.postgresql import BYTEA

from tshistory import schema


L = logging.getLogger('tshistory.tsio')


def tojson(ts):
    if ts is None:
        return None

    if not isinstance(ts.index, pd.MultiIndex):
        return ts.to_json(date_format='iso')
    # multi index case
    return ts.to_frame().reset_index().to_json(date_format='iso')

def num2float(pdobj):
    # get a Series or a Dataframe column
    if str(pdobj.dtype).startswith('int'):
        return pdobj.astype('float64')
    return pdobj


def fromjson(jsonb, tsname):
    return _fromjson(jsonb, tsname).fillna(value=np.nan)


def _fromjson(jsonb, tsname):
    if jsonb == '{}':
        return pd.Series(name=tsname)

    result = pd.read_json(jsonb, typ='series', dtype=False)
    if isinstance(result.index, pd.DatetimeIndex):
        result = num2float(result)
        return result

    # multi index case
    columns = result.index.values.tolist()
    columns.remove(tsname)
    result = pd.read_json(jsonb, typ='frame',
                          convert_dates=columns)
    result.set_index(sorted(columns), inplace=True)
    return num2float(result.iloc[:, 0])  # get a Series object


class TimeSerie(object):
    _csid = None
    _snapshot_interval = 10
    _precision = 1e-14

    # API : changeset, insert, get, delete
    @contextmanager
    def newchangeset(self, cn, author, _insertion_date=None):
        """A context manager to allow insertion of several series within the
        same changeset identifier.

        This makes it possible to group changes to several series,
        hence producing a macro-change.

        _insertion_date is *only* provided for migration purposes and
        not part of the API.
        """
        assert self._csid is None
        self._csid = self._newchangeset(cn, author, _insertion_date)
        self._author = author
        yield
        del self._csid
        del self._author
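
    # A minimal usage sketch (hypothetical `engine`, `series_a`,
    # `series_b` objects, not defined in this module):
    #
    #   tsh = TimeSerie()
    #   with tsh.newchangeset(engine, 'babar'):
    #       tsh.insert(engine, series_a, 'banana.spot')
    #       tsh.insert(engine, series_b, 'banana.future')
    #
    # Both insertions then share a single changeset id and author.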

    def insert(self, cn, newts, name, author=None,
               extra_scalars=None):
        """Create a new revision of a given time series

        newts: pandas.Series with date index

        name: str unique identifier of the serie

        author: str free-form author name (mandatory, unless provided
        to the newchangeset context manager).
        """
        assert self._csid or author, 'author is mandatory'
        if self._csid and author:
            L.info('author will not be used when in a changeset')
        assert isinstance(newts, pd.Series)
        assert not newts.index.duplicated().any()

        # avoid the mutable default argument trap
        extra_scalars = extra_scalars or {}
        newts = num2float(newts)

        if not len(newts):
            return

        newts.name = name
        table = self._get_ts_table(cn, name)

        if isinstance(newts.index, pd.MultiIndex):
            # we impose an order to survive roundtrips
            newts = newts.reorder_levels(sorted(newts.index.names))

        if table is None:
            # initial insertion
            if newts.isnull().all():
                return None
            newts = newts[~newts.isnull()]
            table = self._make_ts_table(cn, name)
            csid = self._csid or self._newchangeset(cn, author)
            value = {
                'csid': csid,
                'snapshot': self._serialize(newts),
            }
            # callback for extenders
            self._complete_insertion_value(value, extra_scalars)
            cn.execute(table.insert().values(value))
            self._finalize_insertion(cn, csid, name)
            L.info('first insertion of %s (size=%s) by %s',
                   name, len(newts), author or self._author)
            return newts

        diff, newsnapshot = self._compute_diff_and_newsnapshot(
            cn, table, newts, **extra_scalars
        )
        if diff is None:
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author or self._author, len(newts))
            return

        tip_id = self._get_tip_id(cn, table)
        csid = self._csid or self._newchangeset(cn, author)
        value = {
            'csid': csid,
            'diff': self._serialize(diff),
            'snapshot': self._serialize(newsnapshot),
            'parent': tip_id,
        }
        # callback for extenders
        self._complete_insertion_value(value, extra_scalars)
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)

        # keep a snapshot only every `_snapshot_interval` revisions:
        # drop the previous tip's snapshot unless its id is a multiple
        # of the interval (the id 1 snapshot is always kept)
        if tip_id > 1 and tip_id % self._snapshot_interval:
            self._purge_snapshot_at(cn, table, tip_id)
        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author or self._author)
        return diff
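
    # Insertion sketch (hypothetical engine/series): the second call
    # stores only the pointwise difference, not a full copy.
    #
    #   ts = pd.Series([1.0, 2.0],
    #                  index=pd.date_range('2017-1-1', periods=2, freq='D'))
    #   tsh.insert(engine, ts, 'demo', author='babar')  # snapshot stored
    #   ts.iloc[-1] = 3.0
    #   tsh.insert(engine, ts, 'demo', author='babar')  # one-point diff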
    def get(self, cn, name, revision_date=None):
        """Compute and return the serie of a given name

        revision_date: datetime filter to get previous versions of the
        serie
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        qfilter = []
        if revision_date:
            qfilter.append(lambda cset, _: cset.c.insertion_date <= revision_date)
        current = self._build_snapshot_upto(cn, table, qfilter)

        if current is not None:
            current.name = name
            current = current[~current.isnull()]
        return current
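
    # Read-back sketch (hypothetical engine and a previously inserted
    # 'demo' serie):
    #
    #   tsh.get(engine, 'demo')                       # latest state
    #   tsh.get(engine, 'demo',
    #           revision_date=datetime(2017, 1, 1))   # as-of a past date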
    def get_group(self, cn, name, revision_date=None):
        csid = self._latest_csid_for(cn, name)

        group = {}
        for seriename in self._changeset_series(cn, csid):
            serie = self.get(cn, seriename, revision_date)
            if serie is not None:
                group[seriename] = serie
        return group

    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None):
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        logs = self.log(cn, names=[name],
                        fromdate=from_insertion_date,
                        todate=to_insertion_date)
        series = []
        for log in logs:
            serie = self.get(cn, name, revision_date=log['date'])
            revdate = pd.Timestamp(log['date'])
            mindex = [(revdate, valuestamp) for valuestamp in serie.index]
            serie.index = pd.MultiIndex.from_tuples(mindex, names=['insertion_date', 'value_date'])
            series.append(serie)
        if not series:
            return
        return pd.concat(series)
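
    # The result carries an (insertion_date, value_date) MultiIndex,
    # one block per revision, e.g. (illustrative values):
    #
    #   insertion_date  value_date
    #   2017-01-01      2017-01-01    1.0
    #                   2017-01-02    2.0
    #   2017-02-01      2017-01-01    1.0
    #                   2017-01-02    3.0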

    def exists(self, cn, name):
        return self._get_ts_table(cn, name) is not None

    def latest_insertion_date(self, cn, name):
        cset = schema.changeset
        tstable = self._get_ts_table(cn, name)
        sql = select([func.max(cset.c.insertion_date)]
        ).where(tstable.c.csid == cset.c.id)
        return cn.execute(sql).scalar()
    def info(self, cn):
        """Gather global statistics on the current tshistory repository
        """
        sql = 'select count(*) from registry'
        stats = {'series count': cn.execute(sql).scalar()}
        sql = 'select max(id) from changeset'
        stats['changeset count'] = cn.execute(sql).scalar()
        sql = 'select distinct name from registry order by name'
        stats['serie names'] = [row for row, in cn.execute(sql).fetchall()]
        return stats

    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.
        """
        log = []
        cset, cset_series, reg = schema.changeset, schema.changeset_series, schema.registry

        sql = select([cset.c.id, cset.c.author, cset.c.insertion_date]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)

        if names:
            sql = sql.where(reg.c.name.in_(names))

        if authors:
            sql = sql.where(cset.c.author.in_(authors))

        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)

        if torev:
            sql = sql.where(cset.c.id <= torev)

        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)

        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        sql = sql.where(cset.c.id == cset_series.c.csid
        ).where(cset_series.c.serie == reg.c.name)

        rset = cn.execute(sql)
        for csetid, author, revdate in rset.fetchall():
            log.append({'rev': csetid, 'author': author, 'date': revdate,
                        'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in log:
                rev['diff'] = {name: self._diff(cn, rev['rev'], name)
                               for name in rev['names']}

        log.sort(key=lambda rev: rev['rev'])
        return log
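
    # Shape of one log entry (illustrative values):
    #
    #   {'rev': 1, 'author': 'babar',
    #    'date': datetime(2017, 1, 1, 10, 0),
    #    'names': ['demo']}
    #
    # With diff=True each entry also carries a 'diff' dict mapping
    # serie names to the deserialized diff of that changeset.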

    # /API
    # Helpers
    # ts serialisation

    def _serialize(self, ts):
        return zlib.compress(tojson(ts).encode('utf-8'))

    def _deserialize(self, ts, name):
        return fromjson(zlib.decompress(ts).decode('utf-8'), name)
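
    # Round-trip sketch: a serie is stored as zlib-compressed JSON with
    # ISO dates, suitable for the BYTEA columns:
    #
    #   ts = pd.Series([1.0], index=[pd.Timestamp('2017-01-01')])
    #   blob = self._serialize(ts)      # bytes
    #   self._deserialize(blob, 'x')    # -> the same values back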

    # serie table handling
    def _ts_table_name(self, seriename):
        # namespace.seriename
        return 'timeserie.%s' % seriename
    def _table_definition_for(self, seriename):
        return Table(
            seriename, schema.meta,
            Column('id', Integer, primary_key=True),
            Column('csid', Integer, ForeignKey('changeset.id'),
                   index=True, nullable=False),
            # constraint: there is either .diff or .snapshot
            Column('diff', BYTEA),
            Column('snapshot', BYTEA),
            Column('parent',
                   Integer,
                   ForeignKey('timeserie.%s.id' % seriename,
                              ondelete='cascade'),
                   nullable=True,
                   unique=True,
                   index=True),
            schema='timeserie',
            extend_existing=True
        )
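
    # Each serie table is thus a linked list of revisions: `parent`
    # points to the previous row, `diff` holds the change brought by a
    # changeset and `snapshot`, when not purged, the full state at that
    # point.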

    def _make_ts_table(self, cn, name):
        tablename = self._ts_table_name(name)
        table = self._table_definition_for(name)
        table.create(cn)
        sql = schema.registry.insert().values(
            name=name,
            table_name=tablename)
        cn.execute(sql)
        return table

    def _get_ts_table(self, cn, name):
        reg = schema.registry
        tablename = self._ts_table_name(name)
        sql = reg.select().where(reg.c.table_name == tablename)
        tid = cn.execute(sql).scalar()
        if tid:
            return self._table_definition_for(name)

    # changeset handling

    def _newchangeset(self, cn, author, _insertion_date=None):
        table = schema.changeset
        sql = table.insert().values(
            author=author,
            insertion_date=_insertion_date or datetime.now())
        return cn.execute(sql).inserted_primary_key[0]
    def _latest_csid_for(self, cn, name):
        table = self._get_ts_table(cn, name)
        sql = select([func.max(table.c.csid)])
        return cn.execute(sql).scalar()
    def _changeset_series(self, cn, csid):
        cset_serie = schema.changeset_series
        sql = select([cset_serie.c.serie]
        ).where(cset_serie.c.csid == csid)

        return [seriename for seriename, in cn.execute(sql).fetchall()]
    # insertion handling

    def _get_tip_id(self, cn, table):
        sql = select([func.max(table.c.id)])
        return cn.execute(sql).scalar()
    def _complete_insertion_value(self, value, extra_scalars):
        pass

    def _finalize_insertion(self, cn, csid, name):
        table = schema.changeset_series
        sql = table.insert().values(
            csid=csid,
            serie=name
        )
        cn.execute(sql)
    # snapshot handling

    def _purge_snapshot_at(self, cn, table, diffid):
        cn.execute(
            table.update(
            ).where(table.c.id == diffid
            ).values(snapshot=None)
        )

    def _validate_type(self, oldts, newts, name):
        if (oldts is None or
            oldts.isnull().all() or
            newts.isnull().all()):
            return
        old_type = oldts.dtype
        new_type = newts.dtype
        if new_type != old_type:
            m = 'Type error when inserting {}, new type is {}, type in base is {}'.format(
                name, new_type, old_type)
            raise Exception(m)

    def _compute_diff_and_newsnapshot(self, cn, table, newts, **extra_scalars):
        snapshot = self._build_snapshot_upto(cn, table)
        self._validate_type(snapshot, newts, table.name)
        diff = self._compute_diff(snapshot, newts)

        if len(diff) == 0:
            return None, None

        # full state computation & insertion
        newsnapshot = self._apply_diff(snapshot, diff)
        return diff, newsnapshot

    def _find_snapshot(self, cn, table, qfilter=(), column='snapshot'):
        cset = schema.changeset
        sql = select([table.c.id, table.c[column]]
        ).order_by(desc(table.c.id)
        ).limit(1
        ).where(table.c[column] != None)

        if qfilter:
            sql = sql.where(table.c.csid == cset.c.id)
            for filtercb in qfilter:
                sql = sql.where(filtercb(cset, table))

        try:
            snapid, snapdata = cn.execute(sql).fetchone()
        except TypeError:
            # fetchone() returned None: no matching snapshot
            return None, None
        return snapid, self._deserialize(snapdata, table.name)
    def _build_snapshot_upto(self, cn, table, qfilter=()):
        snapid, snapshot = self._find_snapshot(cn, table, qfilter)
        if snapid is None:
            return None

        cset = schema.changeset
        sql = select([table.c.id,
                      table.c.diff,
                      table.c.parent,
                      cset.c.insertion_date]
        ).order_by(table.c.id
        ).where(table.c.id > snapid)

        if qfilter:
            sql = sql.where(table.c.csid == cset.c.id)
            for filtercb in qfilter:
                sql = sql.where(filtercb(cset, table))

        alldiffs = pd.read_sql(sql, cn)

        if len(alldiffs) == 0:
            return snapshot

        # initial ts
        ts = snapshot
        for _, row in alldiffs.iterrows():
            diff = self._deserialize(row['diff'], table.name)
            ts = self._apply_diff(ts, diff)
        assert ts.index.dtype.name == 'datetime64[ns]' or len(ts) == 0
        return ts
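
    # Reconstruction sketch: snapshots survive at id 1, at multiples of
    # `_snapshot_interval` and at the tip. Reading as-of revision 13
    # with a tip at 17 starts from the snapshot at id 10 and replays
    # the diffs of rows 11..13 in id order.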

    # diff handling

    def _diff(self, cn, csetid, name):
        table = self._get_ts_table(cn, name)
        cset = schema.changeset

        def filtercset(sql):
            return sql.where(table.c.csid == cset.c.id
            ).where(cset.c.id == csetid)

        sql = filtercset(select([table.c.id]))
        tsid = cn.execute(sql).scalar()

        if tsid == 1:
            sql = select([table.c.snapshot])
        else:
            sql = select([table.c.diff])
        sql = filtercset(sql)

        return self._deserialize(cn.execute(sql).scalar(), name)
    def _compute_diff(self, fromts, tots):
        """Compute the difference between fromts and tots
        (like in tots - fromts).
        """
        if fromts is None:
            return tots
        fromts = fromts[~fromts.isnull()]
        if not len(fromts):
            return tots

        mask_overlap = tots.index.isin(fromts.index)
        fromts_overlap = fromts[tots.index[mask_overlap]]
        tots_overlap = tots[mask_overlap]

        if fromts.dtype == 'float64':
            mask_equal = np.isclose(fromts_overlap, tots_overlap,
                                    rtol=0, atol=self._precision)
        else:
            mask_equal = fromts_overlap == tots_overlap

        mask_na_equal = fromts_overlap.isnull() & tots_overlap.isnull()
        mask_equal = mask_equal | mask_na_equal

        diff_overlap = tots[mask_overlap][~mask_equal]
        diff_new = tots[~mask_overlap]
        diff_new = diff_new[~diff_new.isnull()]
        return pd.concat([diff_overlap, diff_new])
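
    # Diff semantics sketch (illustrative values): points of `tots`
    # that are new or changed (beyond `_precision`) are kept, unchanged
    # points are dropped.
    #
    #   fromts: {d1: 1.0, d2: 2.0}
    #   tots:   {d1: 1.0, d2: 3.0, d3: 4.0}
    #   diff:   {d2: 3.0, d3: 4.0}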

    def _apply_diff(self, base_ts, new_ts):
        """Produce a new ts using base_ts as a base and taking any
507
        intersecting and new values from new_ts.
508
509
510
511
512
513
514
515
516
517

        """
        if base_ts is None:
            return new_ts
        if new_ts is None:
            return base_ts
        result_ts = pd.Series([0.0], index=base_ts.index.union(new_ts.index))
        result_ts[base_ts.index] = base_ts
        result_ts[new_ts.index] = new_ts
        result_ts.sort_index(inplace=True)
518
        result_ts.name = base_ts.name
519
        return result_ts
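
    # Apply sketch (illustrative values): new_ts wins on intersecting
    # stamps, base_ts provides the rest.
    #
    #   base_ts: {d1: 1.0, d2: 2.0}
    #   new_ts:  {d2: 3.0, d3: 4.0}
    #   result:  {d1: 1.0, d2: 3.0, d3: 4.0}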