This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy pathpresto.py
More file actions
141 lines (115 loc) · 4.75 KB
/
presto.py
File metadata and controls
141 lines (115 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from logging import raiseExceptions
import re
from .database_types import *
from .base import Database, import_helper, _query_conn
from .base import (
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
TIMESTAMP_PRECISION_POS,
DEFAULT_DATETIME_PRECISION,
DEFAULT_NUMERIC_PRECISION,
)
@import_helper("presto")
def import_presto():
    """Import the ``prestodb`` client when called and return the module.

    The import lives inside the function body, so ``prestodb`` is only
    loaded when this function is invoked rather than when this file is
    imported.
    """
    import prestodb as presto_driver

    return presto_driver
class Presto(Database):
    """Database driver for Presto, built on the ``prestodb`` DB-API client.

    Provides the SQL fragments (quoting, hashing, normalization) and the
    type parsing needed to work with Presto tables.
    """

    # Schema used unless a truthy ``schema`` keyword is passed to ``__init__``.
    default_schema = "public"
    # Maps plain (non-parameterized) type names to column type classes.
    TYPE_CLASSES = {
        # Timestamps
        "timestamp with time zone": TimestampTZ,
        "timestamp without time zone": Timestamp,
        "timestamp": Timestamp,
        # Numbers
        "integer": Integer,
        "bigint": Integer,
        "real": Float,
        "double": Float,
        # Text
        "varchar": Text,
    }
    # Passed as ``rounds`` to the timestamp types created in ``_parse_type``.
    ROUNDS_ON_PREC_LOSS = True

    def __init__(self, **kw):
        """Open a connection to Presto.

        All keyword arguments are forwarded to ``prestodb.dbapi.connect``
        except for the ones consumed here:

        * ``schema``: when truthy, also stored as ``self.default_schema``
          (it is still forwarded to ``connect``).
        * ``auth``: when equal to ``"basic"``, ``user`` and ``password`` are
          removed from the arguments and ``auth`` is replaced by a
          ``prestodb.auth.BasicAuthentication`` built from them.
        * ``cert``: removed from the arguments and assigned to the
          connection's HTTP session ``verify`` attribute.

        Raises:
            KeyError: if ``auth == "basic"`` and ``user`` or ``password`` is
                not supplied.
        """
        prestodb = import_presto()

        if kw.get("schema"):
            self.default_schema = kw.get("schema")

        # Credentials are only required (and only consumed) for basic auth;
        # any other auth setting passes the arguments through untouched.
        if kw.get("auth") == "basic":
            if "user" not in kw or "password" not in kw:
                raise KeyError("User or password cannot be missing if auth==basic")
            kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password"))

        # ``cert`` is not forwarded to ``connect``; it is applied to the HTTP
        # session after the connection object exists.
        has_cert = "cert" in kw
        cert = kw.pop("cert", None)
        self._conn = prestodb.dbapi.connect(**kw)
        if has_cert:
            self._conn._http_session.verify = cert

    def quote(self, s: str):
        """Return ``s`` wrapped in double quotes (no escaping is applied)."""
        return f'"{s}"'

    def md5_to_int(self, s: str) -> str:
        """Return a SQL expression converting the MD5 of ``s`` to a decimal.

        The expression takes the hex MD5 digest of the UTF-8 encoding of
        ``s``, keeps the substring starting at position
        ``1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS``, and reads it as a
        base-16 number cast to ``decimal(38, 0)``.
        """
        offset = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
        return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {offset}), 16) as decimal(38, 0))"

    def to_string(self, s: str):
        """Return a SQL expression casting ``s`` to ``varchar``."""
        return f"cast({s} as varchar)"

    def _query(self, sql_code: str) -> list:
        """Run ``sql_code`` through the standard SQL cursor interface.

        Returns all rows when the statement starts with ``select``
        (case-insensitive), the first row when it starts with ``insert``,
        ``create``, ``truncate`` or ``drop``, and ``None`` otherwise.
        """
        c = self._conn.cursor()
        c.execute(sql_code)
        if sql_code.lower().startswith("select"):
            return c.fetchall()
        # A fetch is required for these statements to actually run.
        if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE):
            return c.fetchone()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return a SQL expression rendering ``value`` as a padded timestamp string.

        ``value`` is cast to ``timestamp(6)`` and formatted with fractional
        seconds. The result goes through ``RPAD`` twice: first to length
        ``TIMESTAMP_PRECISION_POS + coltype.precision`` with ``'.'``, then to
        length ``TIMESTAMP_PRECISION_POS + 6`` with ``'0'``.
        """
        # TODO: ``coltype.rounds`` is not taken into account yet; the same
        # expression is produced whether or not the column type rounds.
        s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"

        return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return a SQL expression rendering ``value`` as a decimal string.

        The value is cast to ``decimal(38, coltype.precision)`` and then to
        ``varchar``.
        """
        return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")

    def select_table_schema(self, path: DbPath) -> str:
        """Return the query listing the columns of the table at ``path``.

        The query selects ``column_name`` and ``data_type`` from
        ``INFORMATION_SCHEMA.COLUMNS`` and reports the constant ``3`` for both
        ``datetime_precision`` and ``numeric_precision``.
        """
        schema, table = self._normalize_table_path(path)

        # NOTE(review): schema and table are interpolated without escaping.
        return (
            f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM INFORMATION_SCHEMA.COLUMNS "
            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
        )

    def _parse_type(
        self,
        table_path: DbPath,
        col_name: str,
        type_repr: str,
        datetime_precision: int = None,
        numeric_precision: int = None,
    ) -> ColType:
        """Map a Presto type string to a column type instance.

        Parameterized forms are recognized here, each matched against the
        whole of ``type_repr``:

        * ``timestamp(p)`` / ``timestamp(p) with time zone`` (single-digit
          ``p``): timestamp type with precision ``p``.
        * ``decimal(p,s)``: ``Decimal`` constructed from the scale ``s``.
        * ``varchar(n)`` / ``char(n)``: ``Text``.

        Anything else is delegated to the base class implementation.
        """
        timestamp_regexps = {
            r"timestamp\((\d)\)": Timestamp,
            r"timestamp\((\d)\) with time zone": TimestampTZ,
        }
        for regexp, t_cls in timestamp_regexps.items():
            m = re.match(regexp + "$", type_repr)
            if m:
                # The precision comes from the type string itself, not from
                # the ``datetime_precision`` argument.
                return t_cls(
                    precision=int(m.group(1)),
                    rounds=self.ROUNDS_ON_PREC_LOSS,
                )

        number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
        for regexp, n_cls in number_regexps.items():
            m = re.match(regexp + "$", type_repr)
            if m:
                # Only the scale (second group) is used; the declared
                # precision is ignored.
                scale = int(m.group(2))
                return n_cls(scale)

        string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
        for regexp, n_cls in string_regexps.items():
            m = re.match(regexp + "$", type_repr)
            if m:
                return n_cls()

        return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Return a SQL expression rendering ``value`` as a trimmed ``VARCHAR``."""
        # Trim doesn't work on CHAR type
        return f"TRIM(CAST({value} AS VARCHAR))"