from datetime import datetime
from contextlib import contextmanager
import logging
import zlib
import math

import pandas as pd
from pandas.api.types import is_datetimetz
import numpy as np

from sqlalchemy import Table, Column, Integer, ForeignKey
from sqlalchemy.sql.expression import select, func, desc
from sqlalchemy.dialects.postgresql import BYTEA

from tshistory.schema import tsschema


L = logging.getLogger('tshistory.tsio')


def tojson(ts):
    if not isinstance(ts.index, pd.MultiIndex):
        return ts.to_json(date_format='iso',
                          double_precision=-int(math.log10(TimeSerie._precision)))

    # multi index case
    return ts.to_frame().reset_index().to_json(date_format='iso')


def num2float(pdobj):
    # get a Series or a Dataframe column
    if str(pdobj.dtype).startswith('int'):
        return pdobj.astype('float64')
    return pdobj


def fromjson(jsonb, tsname):
    return _fromjson(jsonb, tsname).fillna(value=np.nan)


def _fromjson(jsonb, tsname):
    if jsonb == '{}':
        return pd.Series(name=tsname)

    result = pd.read_json(jsonb, typ='series', dtype=False)
    result.name = tsname
    if isinstance(result.index, pd.DatetimeIndex):
        result = num2float(result)
        return result

    # multi index case
    columns = result.index.values.tolist()
    columns.remove(tsname)
    result = pd.read_json(jsonb, typ='frame',
                          convert_dates=columns)
    result.set_index(sorted(columns), inplace=True)
    return num2float(result.iloc[:, 0])  # get a Series object
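
# A hedged round-trip sketch (illustrative, not executed): a series
# serialized with `tojson` should come back through `fromjson` intact,
# up to the float precision implied by TimeSerie._precision:
#
#   ts = pd.Series([1., 2.], index=pd.date_range('2020-1-1', periods=2))
#   fromjson(tojson(ts), 'ts')  # ~ equals ts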


def tzaware_serie(ts):
    if isinstance(ts.index, pd.MultiIndex):
        tzaware = [is_datetimetz(ts.index.get_level_values(idx_name))
                   for idx_name in ts.index.names]
        assert all(tzaware) or not any(tzaware), (
            'either all of your indexes must be '
            'tzaware, or none of them'
        )
        return all(tzaware)
    return is_datetimetz(ts.index)


def subset(ts, fromdate, todate):
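    """Slice `ts` to the [fromdate, todate] range; tuple bounds are
    reduced to their first component, and unsorted multi-indexes get
    lexsorted first.
    """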
    if fromdate is None and todate is None:
        return ts
    if isinstance(fromdate, tuple):
        fromdate = fromdate[0]
    if isinstance(todate, tuple):
        todate = todate[0]
    if isinstance(ts.index, pd.MultiIndex):
        if not ts.index.lexsort_depth:
            ts.sort_index(inplace=True)
    return ts.loc[fromdate:todate]


def inject_in_index(serie, revdate):
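    """Prepend `revdate` as an `insertion_date` level of the series
    index, in place.
    """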
    if isinstance(serie.index, pd.MultiIndex):
        mindex = [(revdate, *rest) for rest in serie.index]
        serie.index = pd.MultiIndex.from_tuples(mindex, names=[
            'insertion_date', *serie.index.names]
        )
        return
    mindex = [(revdate, valuestamp) for valuestamp in serie.index]
    serie.index = pd.MultiIndex.from_tuples(mindex, names=[
        'insertion_date', 'value_date']
    )


class TimeSerie(object):
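    """Store and retrieve versioned timeseries over PostgreSQL.

    A sketch of the model, as implemented below: each series lives in
    its own table as a chain of diffs punctuated by snapshots, and
    writes are grouped into changesets.
    """
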
    _csid = None
    _snapshot_interval = 10
    _precision = 1e-14
    namespace = 'tsh'
    schema = None

    def __init__(self, namespace='tsh'):
        self.namespace = namespace
        self.schema = tsschema(namespace)
        self.schema.define()
        self.metadatacache = {}

    # API : changeset, insert, get, delete
    @contextmanager
    def newchangeset(self, cn, author, _insertion_date=None):
        """A context manager to allow insertion of several series within the
        same changeset identifier

        This allows grouping changes to several series, hence
        producing a macro-change.

        _insertion_date is *only* provided for migration purposes and
        not part of the API.
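
        A hypothetical usage sketch (the engine, author and series
        names are illustrative only):

            tsh = TimeSerie()
            with engine.connect() as cn:
                with tsh.newchangeset(cn, 'babar'):
                    tsh.insert(cn, series1, 'banana_spot')
                    tsh.insert(cn, series2, 'banana_price')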
        """
        assert self._csid is None
        self._csid = self._newchangeset(cn, author, _insertion_date)
        self._author = author
        yield
        del self._csid
        del self._author

    def insert(self, cn, newts, name, author=None, _insertion_date=None,
               extra_scalars={}):
        """Create a new revision of a given time series

        newts: pandas.Series with date index

        name: str unique identifier of the series

        author: str free-form author name (mandatory, unless provided
        to the newchangeset context manager).

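        A minimal sketch (the engine and data are illustrative):

            ts = pd.Series([1., 2.],
                           index=pd.date_range('2020-1-1', periods=2))
            tsh.insert(cn, ts, 'banana_spot', author='babar')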
        """
        assert self._csid or author, 'author is mandatory'
        if self._csid and author:
            L.info('author {} will not be used when in a changeset'.format(author))
            author = None
        assert isinstance(newts, pd.Series)
        assert not newts.index.duplicated().any()

        newts = num2float(newts)

        if not len(newts):
            return

        assert ('<M8[ns]' == newts.index.dtype or
                'datetime' in str(newts.index.dtype) or
                isinstance(newts.index, pd.MultiIndex))

        newts.name = name
        table = self._get_ts_table(cn, name)

        if isinstance(newts.index, pd.MultiIndex):
            # we impose an order to survive roundtrips
            newts = newts.reorder_levels(sorted(newts.index.names))

        if table is None:
            # initial insertion
            if newts.isnull().all():
                return None
            newts = newts[~newts.isnull()]
            table = self._make_ts_table(cn, name, newts)
            csid = self._csid or self._newchangeset(cn, author, _insertion_date)
            value = {
                'csid': csid,
                'snapshot': self._serialize(newts),
            }
            # callback for extenders
            self._complete_insertion_value(value, extra_scalars)
            cn.execute(table.insert().values(value))
            self._finalize_insertion(cn, csid, name)
            L.info('first insertion of %s (size=%s) by %s',
                   name, len(newts), author or self._author)
            return newts

        diff, newsnapshot = self._compute_diff_and_newsnapshot(
            cn, table, newts, **extra_scalars
        )
        if diff is None:
            L.info('no difference in %s by %s (for ts of size %s)',
                   name, author or self._author, len(newts))
            return

        tip_id = self._get_tip_id(cn, table)
        csid = self._csid or self._newchangeset(cn, author, _insertion_date)
        value = {
            'csid': csid,
            'diff': self._serialize(diff),
            'snapshot': self._serialize(newsnapshot),
            'parent': tip_id,
        }
        # callback for extenders
        self._complete_insertion_value(value, extra_scalars)
        cn.execute(table.insert().values(value))
        self._finalize_insertion(cn, csid, name)

        if tip_id > 1 and tip_id % self._snapshot_interval:
            self._purge_snapshot_at(cn, table, tip_id)
        L.info('inserted diff (size=%s) for ts %s by %s',
               len(diff), name, author or self._author)
        return diff

    def get(self, cn, name, revision_date=None,
            from_value_date=None, to_value_date=None):
        """Compute and return the series of a given name

        revision_date: datetime filter to get previous versions of the
        series

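        A sketch (names are illustrative):

            tsh.get(cn, 'banana_spot')
            tsh.get(cn, 'banana_spot', revision_date=datetime(2017, 1, 1))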
        """
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        qfilter = []
        if revision_date:
            qfilter.append(lambda cset, _: cset.c.insertion_date <= revision_date)
        current = self._build_snapshot_upto(cn, table, qfilter,
                                            from_value_date=from_value_date,
                                            to_value_date=to_value_date)

        if current is not None:
            current.name = name
            current = current[~current.isnull()]
        return current

    def metadata(self, cn, tsname):
        """Return the metadata dict of a timeseries."""
        if tsname in self.metadatacache:
            return self.metadatacache[tsname]
        reg = self.schema.registry
        sql = select([reg.c.metadata]).where(
            reg.c.name == tsname
        )
        meta = cn.execute(sql).scalar()
        self.metadatacache[tsname] = meta
        return meta

    def get_group(self, cn, name, revision_date=None):
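        """Return, as a dict, all the series which belong to the same
        changeset as the latest revision of `name`.
        """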
        csid = self._latest_csid_for(cn, name)

        group = {}
        for seriename in self._changeset_series(cn, csid):
            serie = self.get(cn, seriename, revision_date)
            if serie is not None:
                group[seriename] = serie
        return group

    def get_history(self, cn, name,
                    from_insertion_date=None,
                    to_insertion_date=None,
                    from_value_date=None,
                    to_value_date=None,
                    diffmode=False):
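        """Return a compound of the revisions of a given series, as a
        pandas series with an (insertion_date, value_date) multi-index
        (or only the successive raw diffs when diffmode is True).
        """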
        table = self._get_ts_table(cn, name)
        if table is None:
            return

        # compute diffs above the snapshot
        cset = self.schema.changeset
        diffsql = select([cset.c.id, cset.c.insertion_date, table.c.diff]
        ).order_by(cset.c.id
        ).where(table.c.csid == cset.c.id)

        if from_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date >= from_insertion_date)
        if to_insertion_date:
            diffsql = diffsql.where(cset.c.insertion_date <= to_insertion_date)

        diffs = cn.execute(diffsql).fetchall()
        if not diffs:
            # it's fine to ask for an insertion date range
            # where nothing happened, but then you get nothing
            return

        if diffmode:
            series = []
            for csid, revdate, diff in diffs:
                if diff is None:  # we must fetch the initial snapshot
                    sql = select([table.c.snapshot]).where(table.c.csid == csid)
                    diff = cn.execute(sql).scalar()
                serie = subset(self._deserialize(diff, name), from_value_date, to_value_date)
                serie = self._ensure_tz_consistency(cn, serie)
                inject_in_index(serie, revdate)
                series.append(serie)
            series = pd.concat(series)
            series.name = name
            return series

        csid, revdate, diff_ = diffs[0]
        snapshot = self._build_snapshot_upto(cn, table, [
            lambda cset, _: cset.c.id <= csid
        ], from_value_date, to_value_date)

        series = [(revdate, subset(snapshot, from_value_date, to_value_date))]
        for csid_, revdate, diff in diffs[1:]:
            diff = subset(self._deserialize(diff, table.name),
                          from_value_date, to_value_date)
            diff = self._ensure_tz_consistency(cn, diff)

            serie = self._apply_diff(series[-1][1], diff)
            series.append((revdate, serie))

        for revdate, serie in series:
            inject_in_index(serie, revdate)

        serie = pd.concat([serie for revdate_, serie in series])
        serie.name = name
        return serie

    def exists(self, cn, name):
        return self._get_ts_table(cn, name) is not None

    def latest_insertion_date(self, cn, name):
        cset = self.schema.changeset
        tstable = self._get_ts_table(cn, name)
        sql = select([func.max(cset.c.insertion_date)]
        ).where(tstable.c.csid == cset.c.id)
        return cn.execute(sql).scalar()

    def changeset_at(self, cn, seriename, revdate, mode='strict'):
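        """Return the id of the changeset of `seriename` whose insertion
        date matches `revdate` exactly ('strict') or bounds it
        ('before'/'after').
        """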
        assert mode in ('strict', 'before', 'after')
        cset = self.schema.changeset
        table = self._table_definition_for(seriename)
        sql = select([table.c.csid]).where(
            table.c.csid == cset.c.id
        )
        if mode == 'strict':
            sql = sql.where(cset.c.insertion_date == revdate)
        elif mode == 'before':
            sql = sql.where(cset.c.insertion_date <= revdate)
        else:
            sql = sql.where(cset.c.insertion_date >= revdate)
        return cn.execute(sql).scalar()

    def strip(self, cn, seriename, csid):
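        """Strip `seriename` from changeset `csid` onwards: mark the
        affected changesets metadata, detach the series from them, wipe
        the diffs at `csid` and rebuild the top-level snapshot if
        needed.
        """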
        logs = self.log(cn, fromrev=csid, names=(seriename,))
        assert logs

        # put stripping info in the metadata
        cset = self.schema.changeset
        cset_serie = self.schema.changeset_series
        for log in logs:
            # update changeset.metadata
            metadata = cn.execute(
                select([cset.c.metadata]).where(cset.c.id == log['rev'])
            ).scalar() or {}
            metadata['tshistory.info'] = 'got stripped from {}'.format(csid)
            sql = cset.update().where(cset.c.id == log['rev']
            ).values(metadata=metadata)
            cn.execute(sql)
            # delete changeset_series item
            sql = cset_serie.delete().where(
                cset_serie.c.csid == log['rev']
            ).where(
                cset_serie.c.serie == seriename
            )
            cn.execute(sql)

        # wipe the diffs
        table = self._table_definition_for(seriename)
        cn.execute(table.delete().where(table.c.csid == csid))
        # rebuild the top-level snapshot
        cstip = self._latest_csid_for(cn, seriename)
        if cn.execute(select([table.c.snapshot]).where(table.c.csid == cstip)).scalar() is None:
            snap = self._build_snapshot_upto(
                cn, table,
                qfilter=(lambda cset, _t: cset.c.id < csid,)
            )
            sql = table.update().where(
                table.c.csid == cstip
            ).values(
                snapshot=self._serialize(snap)
            )
            cn.execute(sql)

    def info(self, cn):
        """Gather global statistics on the current tshistory repository
        """
        sql = 'select count(*) from {}.registry'.format(self.namespace)
        stats = {'series count': cn.execute(sql).scalar()}
        sql = 'select max(id) from {}.changeset'.format(self.namespace)
        stats['changeset count'] = cn.execute(sql).scalar()
        sql = 'select distinct name from {}.registry order by name'.format(self.namespace)
        stats['serie names'] = [row for row, in cn.execute(sql).fetchall()]
        return stats

    def log(self, cn, limit=0, diff=False, names=None, authors=None,
            stripped=False,
            fromrev=None, torev=None,
            fromdate=None, todate=None):
        """Build a structure showing the history of all the series in the db,
        per changeset, in chronological order.
        """
        log = []
        cset, cset_series, reg = (
            self.schema.changeset,
            self.schema.changeset_series,
            self.schema.registry
        )

        sql = select([cset.c.id, cset.c.author, cset.c.insertion_date, cset.c.metadata]
        ).distinct().order_by(desc(cset.c.id))

        if limit:
            sql = sql.limit(limit)

        if names:
            sql = sql.where(reg.c.name.in_(names))

        if authors:
            sql = sql.where(cset.c.author.in_(authors))

        if fromrev:
            sql = sql.where(cset.c.id >= fromrev)

        if torev:
            sql = sql.where(cset.c.id <= torev)

        if fromdate:
            sql = sql.where(cset.c.insertion_date >= fromdate)

        if todate:
            sql = sql.where(cset.c.insertion_date <= todate)

        if stripped:
            # outerjoin to also show the stripped (dead) changesets
            sql = sql.select_from(cset.outerjoin(cset_series))
        else:
            sql = sql.where(cset.c.id == cset_series.c.csid
            ).where(cset_series.c.serie == reg.c.name)

        rset = cn.execute(sql)
        for csetid, author, revdate, meta in rset.fetchall():
            log.append({'rev': csetid, 'author': author, 'date': revdate,
                        'meta': meta or {},
                        'names': self._changeset_series(cn, csetid)})

        if diff:
            for rev in log:
                rev['diff'] = {name: self._diff(cn, rev['rev'], name)
                               for name in rev['names']}

        log.sort(key=lambda rev: rev['rev'])
        return log

    # /API
    # Helpers

    # ts serialisation

    def _serialize(self, ts):
        if ts is None:
            return None
        return zlib.compress(tojson(ts).encode('utf-8'))

    def _deserialize(self, ts, name):
        return fromjson(zlib.decompress(ts).decode('utf-8'), name)

    def _ensure_tz_consistency(self, cn, ts):
        """Return the timeseries with a tz-aware index or not, depending
        on its tzaware metadata.
        """
        assert ts.name is not None
        metadata = self.metadata(cn, ts.name)
        if metadata and metadata.get('tzaware', False):
            if isinstance(ts.index, pd.MultiIndex):
                for i in range(len(ts.index.levels)):
                    ts.index = ts.index.set_levels(
                        ts.index.levels[i].tz_localize('UTC'),
                        level=i)
                return ts
            return ts.tz_localize('UTC')
        return ts

    # series table handling

    def _ts_table_name(self, seriename):
        # namespace.seriename
        return '{}.timeserie.{}'.format(self.namespace, seriename)

    def _table_definition_for(self, seriename):
        return Table(
            seriename, self.schema.meta,
            Column('id', Integer, primary_key=True),
            Column('csid', Integer,
                   ForeignKey('{}.changeset.id'.format(self.namespace)),
                   index=True, nullable=False),
            # constraint: there is either .diff or .snapshot
            Column('diff', BYTEA),
            Column('snapshot', BYTEA),
            Column('parent',
                   Integer,
                   ForeignKey('{}.timeserie.{}.id'.format(self.namespace,
                                                          seriename),
                              ondelete='cascade'),
                   nullable=True,
                   unique=True,
                   index=True),
            schema='{}.timeserie'.format(self.namespace),
            extend_existing=True
        )

    def _make_ts_table(self, cn, name, ts):
        tablename = self._ts_table_name(name)
        table = self._table_definition_for(name)
        table.create(cn)
        index = ts.index
        inames = [name for name in index.names if name]
        sql = self.schema.registry.insert().values(
            name=name,
            table_name=tablename,
            metadata={
                'tzaware': tzaware_serie(ts),
                'index_type': index.dtype.name,
                'index_names': inames,
                'value_type': ts.dtypes.name
            },
        )
        cn.execute(sql)
        return table

    def _get_ts_table(self, cn, name):
        reg = self.schema.registry
        tablename = self._ts_table_name(name)
        sql = reg.select().where(reg.c.table_name == tablename)
        tid = cn.execute(sql).scalar()
        if tid:
            return self._table_definition_for(name)

    # changeset handling

    def _newchangeset(self, cn, author, _insertion_date=None):
        table = self.schema.changeset
        sql = table.insert().values(
            author=author,
            insertion_date=_insertion_date or datetime.now())
        return cn.execute(sql).inserted_primary_key[0]

    def _latest_csid_for(self, cn, name):
        table = self._get_ts_table(cn, name)
        sql = select([func.max(table.c.csid)])
        return cn.execute(sql).scalar()

    def _changeset_series(self, cn, csid):
        cset_serie = self.schema.changeset_series
        sql = select([cset_serie.c.serie]
        ).where(cset_serie.c.csid == csid)

        return [seriename for seriename, in cn.execute(sql).fetchall()]

    # insertion handling

    def _get_tip_id(self, cn, table):
        """Get the *local* id."""
        sql = select([func.max(table.c.id)])
        return cn.execute(sql).scalar()

    def _complete_insertion_value(self, value, extra_scalars):
        pass

    def _finalize_insertion(self, cn, csid, name):
        table = self.schema.changeset_series
        sql = table.insert().values(
            csid=csid,
            serie=name
        )
        cn.execute(sql)

    # snapshot handling

    def _purge_snapshot_at(self, cn, table, diffid):
        cn.execute(
            table.update(
            ).where(table.c.id == diffid
            ).values(snapshot=None)
        )

    def _validate(self, cn, name, ts):
        if ts.isnull().all():
            # ts erasure
            return
        meta = self.metadata(cn, name)
        tstype = ts.dtype
        if tstype != meta['value_type']:
            m = 'Type error when inserting {}, new type is {}, type in base is {}'.format(
                name, tstype, meta['value_type'])
            raise Exception(m)
        if ts.index.dtype.name != meta['index_type']:
            raise Exception('Incompatible index types')
        inames = [name for name in ts.index.names if name]
        if inames != meta['index_names']:
            raise Exception('Incompatible multi indexes: {} vs {}'.format(
                meta['index_names'], inames)
            )

    def _compute_diff_and_newsnapshot(self, cn, table, newts, **extra_scalars):
        self._validate(cn, table.name, newts)
        snapshot = self._build_snapshot_upto(cn, table)
        assert snapshot is not None
        diff = self._compute_diff(snapshot, newts)

        if len(diff) == 0:
            return None, None

        # full state computation & insertion
        newsnapshot = self._apply_diff(snapshot, diff)
        return diff, newsnapshot

    def _find_snapshot(self, cn, table, qfilter=(), column='snapshot',
                       from_value_date=None, to_value_date=None):
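        """Return (id, series) of the latest stored snapshot row
        matching qfilter, or (None, None) when there is none.
        """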
        cset = self.schema.changeset
        sql = select([table.c.id, table.c[column]]
        ).order_by(desc(table.c.id)
        ).limit(1
        ).where(table.c[column] != None
        ).select_from(table.join(cset))

        if qfilter:
            sql = sql.where(table.c.csid <= cset.c.id)
            for filtercb in qfilter:
                sql = sql.where(filtercb(cset, table))

        try:
            snapid, snapdata = cn.execute(sql).fetchone()
            snapdata = subset(self._deserialize(snapdata, table.name),
                              from_value_date, to_value_date)
            snapdata = self._ensure_tz_consistency(cn, snapdata)
        except TypeError:
            return None, None
        return snapid, snapdata

    def _build_snapshot_upto(self, cn, table, qfilter=(),
                             from_value_date=None, to_value_date=None):
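        """Reconstruct the state of the series by applying, in order,
        the diffs stacked on top of the latest snapshot matching
        qfilter.
        """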
        snapid, snapshot = self._find_snapshot(cn, table, qfilter,
                                               from_value_date=from_value_date,
                                               to_value_date=to_value_date)
        if snapid is None:
            return None

        cset = self.schema.changeset
        # beware the potential cartesian product
        # between table & cset if there is no qfilter
        sql = select([table.c.id,
                      table.c.diff,
                      table.c.parent,
                      cset.c.insertion_date]
        ).order_by(table.c.id
        ).where(table.c.id > snapid)

        if qfilter:
            sql = sql.where(table.c.csid == cset.c.id)
            for filtercb in qfilter:
                sql = sql.where(filtercb(cset, table))

        alldiffs = pd.read_sql(sql, cn)

        if len(alldiffs) == 0:
            return snapshot

        # initial ts
        ts = self._deserialize(alldiffs.loc[0, 'diff'], table.name)
        ts = self._ensure_tz_consistency(cn, ts)
        for row in alldiffs.loc[1:].itertuples():
            diff = subset(self._deserialize(row.diff, table.name),
                          from_value_date, to_value_date)
            diff = self._ensure_tz_consistency(cn, diff)
            ts = self._apply_diff(ts, diff)
        ts = self._apply_diff(snapshot, ts)
        assert ts.index.dtype.name == 'datetime64[ns]' or len(ts) == 0
        return ts

    # diff handling

    def _diff(self, cn, csetid, name):
        table = self._get_ts_table(cn, name)
        cset = self.schema.changeset

        def filtercset(sql):
            return sql.where(table.c.csid == cset.c.id
            ).where(cset.c.id == csetid)

        sql = filtercset(select([table.c.id]))
        tsid = cn.execute(sql).scalar()

        if tsid == 1:
            sql = select([table.c.snapshot])
        else:
            sql = select([table.c.diff])
        sql = filtercset(sql)

        ts = self._deserialize(cn.execute(sql).scalar(), name)
        return self._ensure_tz_consistency(cn, ts)

    def _compute_diff(self, fromts, tots):
        """Compute the difference between fromts and tots
        (like in tots - fromts).
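
        An illustration with made-up values:

            fromts = pd.Series([1., 2., 3.],
                               index=pd.date_range('2020-1-1', periods=3))
            tots = pd.Series([1., 2.5, 3., 4.],
                             index=pd.date_range('2020-1-1', periods=4))
            # -> the diff keeps the changed point (2.5) and the new
            #    point (4.), and drops the identical ones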

        """
        if fromts is None:
            return tots
        fromts = fromts[~fromts.isnull()]
        if not len(fromts):
            return tots

        mask_overlap = tots.index.isin(fromts.index)
        fromts_overlap = fromts[tots.index[mask_overlap]]
        tots_overlap = tots[mask_overlap]

        if fromts.dtype == 'float64':
            mask_equal = np.isclose(fromts_overlap, tots_overlap,
                                    rtol=0, atol=self._precision)
        else:
            mask_equal = fromts_overlap == tots_overlap

        mask_na_equal = fromts_overlap.isnull() & tots_overlap.isnull()
        mask_equal = mask_equal | mask_na_equal

        diff_overlap = tots[mask_overlap][~mask_equal]
        diff_new = tots[~mask_overlap]
        diff_new = diff_new[~diff_new.isnull()]
        return pd.concat([diff_overlap, diff_new])

    def _apply_diff(self, base_ts, new_ts):
        """Produce a new ts using base_ts as a base and taking any
        intersecting and new values from new_ts.
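
        An illustration with made-up values:

            base = pd.Series([1., 2.],
                             index=pd.date_range('2020-1-1', periods=2))
            diff = pd.Series([2.5, 4.],
                             index=pd.date_range('2020-1-2', periods=2))
            # -> [1., 2.5, 4.] over the union of both indexes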

        """
        if base_ts is None:
            return new_ts
        if new_ts is None:
            return base_ts
        result_ts = pd.Series([0.0], index=base_ts.index.union(new_ts.index))
        result_ts[base_ts.index] = base_ts
        result_ts[new_ts.index] = new_ts
        result_ts.sort_index(inplace=True)
        result_ts.name = base_ts.name
        return result_ts