import os
import math
import zlib
import logging
import threading
import tempfile
import shutil
from contextlib import contextmanager
from pathlib import Path

import numpy as np
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype
from sqlalchemy.engine import url
from sqlalchemy.engine.base import Engine
from inireader import reader


@contextmanager
def tempdir(suffix='', prefix='tmp'):
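    " yield a freshly created temporary directory (as a Path), removed with its contents on exit "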
    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix)
    try:
        yield Path(tmp)
    finally:
        shutil.rmtree(tmp)


def get_cfg_path():
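    " return the first existing config path among $TSHISTORYCFGPATH, ./tshistory.cfg and ~/tshistory.cfg "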
    if 'TSHISTORYCFGPATH' in os.environ:
        cfgpath = Path(os.environ['TSHISTORYCFGPATH'])
        if cfgpath.exists():
            return cfgpath
    cfgpath = Path('tshistory.cfg')
    if cfgpath.exists():
        return cfgpath
    cfgpath = Path('~/tshistory.cfg').expanduser()
    if cfgpath.exists():
        return cfgpath


def find_dburi(something: str) -> str:
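    " return `something` if it already is a valid db uri, else look it up in the [dburi] config section "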
    try:
        url.make_url(something)
    except Exception:
        pass
    else:
        return something

    # lookup in the env, then in cwd, then in the home
    cfgpath = get_cfg_path()
    if not cfgpath:
        raise Exception('could not use nor look up the db uri')

    try:
        cfg = reader(cfgpath)
        return cfg['dburi'][something]
    except Exception as exc:
        raise Exception(
            f'could not find the `{something}` entry in the '
            f'[dburi] section of the `{cfgpath.resolve()}` '
            f'conf file (cause: {exc.__class__.__name__} -> {exc})'
        )


def tzaware_serie(ts):
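    " tell whether the series index is timezone-aware "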
    return is_datetime64tz_dtype(ts.index)


def pruned_history(hist):
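    " prune an {insertion_date: series} history, keeping only the insertion dates at which the series changed "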
    if not hist:
        return hist
    idates = list(hist.keys())
    idate = idates[0]
    current = hist[idate]
    pruned = {
        idate: current
    }
    for idate in idates[1:]:
        newts = hist[idate]
        if not current.equals(newts):
            pruned[idate] = newts
            current = newts
    return pruned


def start_end(ts, notz=True):
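    " return the first and last stamps of the non-nan points (converted to naive utc if notz) "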
    ts = ts.dropna()
    if not len(ts):
        return None, None
    start = ts.index.min()
    end = ts.index.max()
    if start.tzinfo is not None and notz:
        assert end.tzinfo is not None
        start = start.tz_convert('UTC').replace(tzinfo=None)
        end = end.tz_convert('UTC').replace(tzinfo=None)
    return start, end


def closed_overlaps(fromdate, todate):
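    " build an sql clause checking that the (fromdate, todate) query range overlaps a chunk's (tsstart, tsend) "
    # None bounds map to +/- infinity; concrete bounds are passed as %(fromdate)s / %(todate)s parameters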
    fromdate = "'-infinity'" if fromdate is None else '%(fromdate)s'
    todate = "'infinity'" if todate is None else '%(todate)s'
    return f'({fromdate}, {todate}) overlaps (tsstart, tsend + interval \'1 microsecond\')'


def inject_in_index(serie, revdate):
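    " replace the series index with an (insertion_date, value_date) multi-index built from revdate "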
    mindex = [(revdate, valuestamp) for valuestamp in serie.index]
    serie.index = pd.MultiIndex.from_tuples(mindex, names=[
        'insertion_date', 'value_date']
    )


def num2float(pdobj):
    # get a Series or a Dataframe column
    if str(pdobj.dtype).startswith('int'):
        return pdobj.astype('float64')
    return pdobj


def tojson(ts, precision=1e-14):
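    " serialize a series to json with iso dates and a number of decimals derived from `precision` "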
    return ts.to_json(date_format='iso',
                      double_precision=-int(math.log10(precision)))


def fromjson(jsonb, tsname, tzaware=False):
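    " rebuild a named series from its json form, localizing the index to utc when tzaware "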
    series = _fromjson(jsonb, tsname).fillna(value=np.nan)
    if tzaware:
        series.index = series.index.tz_localize('utc')
    return series


def _fromjson(jsonb, tsname):
    if jsonb == '{}':
        return pd.Series(name=tsname)

    result = pd.read_json(jsonb, typ='series', dtype=False)
    result.name = tsname
    result = num2float(result)
    return result


class SeriesServices(object):
    _precision = 1e-14

    # diff handling

    def patch(self, base, diff):
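        " overlay `diff` on `base`: union of both indexes, diff values winning on common stamps "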
        assert base is not None
        assert diff is not None
        newindex = base.index.union(diff.index).sort_values()
        patched = pd.Series([0] * len(newindex), index=newindex)
        patched[base.index] = base
        patched[diff.index] = diff
        patched.name = base.name
        return patched

    def diff(self, base, other):
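        " return the points of `other` that are new or differ from `base` (beyond `_precision`) "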
        if base is None:
            return other
        base = base[~base.isnull()]
        if not len(base):
            return other

        mask_overlap = other.index.isin(base.index)
        base_overlap = base[other.index[mask_overlap]]
        other_overlap = other[mask_overlap]

        if base.dtype == 'float64':
            mask_equal = np.isclose(base_overlap, other_overlap,
                                    rtol=0, atol=self._precision)
        else:
            mask_equal = base_overlap == other_overlap

        mask_na_equal = base_overlap.isnull() & other_overlap.isnull()
        mask_equal = mask_equal | mask_na_equal

        diff_overlap = other[mask_overlap][~mask_equal]
        diff_new = other[~mask_overlap]
        diff_new = diff_new[~diff_new.isnull()]
        return pd.concat([diff_overlap, diff_new])


def delete_series(engine, series, namespace='tsh'):
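    " delete the given series names, skipping the unknown ones "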
    from tshistory.tsio import timeseries
    tsh = timeseries(namespace=namespace)

    for name in series:
        with engine.begin() as cn:
            if not tsh.exists(cn, name):
                print('skipping unknown', name)
                continue
            print('delete', name)
            tsh.delete(cn, name)


def threadpool(maxthreads):
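    " return a runner(func, argslist) that executes func over argslist, keeping at most `maxthreads` threads alive "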
    L = logging.getLogger('parallel')

    def run(func, argslist):
        count = 0
        threads = []
        L.debug('// run %s %s', func.__name__, len(argslist))

        # initial threads
        for count, args in enumerate(argslist, start=1):
            th = threading.Thread(target=func, args=args)
            threads.append(th)
            L.debug('// start thread %s', th.name)
            th.daemon = True
            th.start()
            if count == maxthreads:
                break

        while threads:
            for th in threads[:]:
                th.join(1. / maxthreads)
                if not th.is_alive():
                    threads.remove(th)
                    L.debug('// thread %s exited, %s remaining', th.name, len(threads))
                    if count < len(argslist):
                        newth = threading.Thread(target=func, args=argslist[count])
                        threads.append(newth)
                        L.debug('// thread %s started', newth.name)
                        newth.daemon = True
                        newth.start()
                        count += 1

    return run


def tx(func):
    " a decorator to check that the first method argument is a transaction "
    def check_tx_and_call(self, cn, *a, **kw):
        # safety belt to make sure important api points are tx-safe
        if not isinstance(cn, Engine):
            if not cn.in_transaction():
                raise TypeError('You must use a transaction object')
        else:
            with cn.begin() as txcn:
                return func(self, txcn, *a, **kw)

        return func(self, cn, *a, **kw)
    check_tx_and_call.__name__ = func.__name__
    return check_tx_and_call


class unilist(list):
    " a list which refuses duplicates "

    def append(self, element):
        assert element not in self
        super().append(element)