This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy pathdatabricks.py
More file actions
159 lines (123 loc) · 5.49 KB
/
databricks.py
File metadata and controls
159 lines (123 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from typing import Dict, Sequence
import logging
from .database_types import (
Integer,
Float,
Decimal,
Timestamp,
Text,
TemporalType,
NumericType,
DbPath,
ColType,
UnknownColType,
)
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name
@import_helper(text="You can install it using 'pip install databricks-sql-connector'")
def import_databricks():
    # Lazily import the Databricks SQL connector; the decorator turns an
    # ImportError into a user-friendly install hint.
    # Importing the `databricks.sql` submodule binds the parent `databricks`
    # package in this scope, which is what callers receive.
    import databricks.sql
    return databricks
class Dialect(BaseDialect):
    """SQL dialect definitions for Databricks (Spark SQL flavor)."""

    name = "Databricks"
    ROUNDS_ON_PREC_LOSS = True

    # Mapping from Databricks column type names to the project's column types.
    TYPE_CLASSES = {
        # Numbers
        "INT": Integer,
        "SMALLINT": Integer,
        "TINYINT": Integer,
        "BIGINT": Integer,
        "FLOAT": Float,
        "DOUBLE": Float,
        "DECIMAL": Decimal,
        # Timestamps
        "TIMESTAMP": Timestamp,
        # Text
        "STRING": Text,
    }

    def quote(self, s: str):
        # Spark SQL identifiers are quoted with backticks.
        return "`%s`" % (s,)

    def md5_as_int(self, s: str) -> str:
        # Keep only the trailing CHECKSUM_HEXDIGITS hex digits of the MD5,
        # then convert base-16 -> base-10 and cast to a wide decimal.
        offset = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
        return f"cast(conv(substr(md5({s}), {offset}), 16, 10) as decimal(38, 0))"

    def to_string(self, s: str) -> str:
        return f"cast({s} as string)"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Databricks timestamp contains no more than 6 digits in precision."""
        if not coltype.rounds:
            # Truncation path: keep `precision` fractional digits, zero-pad to 6.
            fraction_fmt = "S" * coltype.precision + "0" * (6 - coltype.precision)
            return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{fraction_fmt}')"
        # Rounding path: round the microsecond count at the requested precision,
        # then format back with the full 6 fractional digits.
        rounded_micros = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)"
        return f"date_format(timestamp_micros({rounded_micros}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')"

    def normalize_number(self, value: str, coltype: NumericType) -> str:
        casted = f"cast({value} as decimal(38, {coltype.precision}))"
        return self.to_string(casted)

    def _convert_db_precision_to_digits(self, p: int) -> int:
        # Subtracting 1 due to weird precision issues
        return max(super()._convert_db_precision_to_digits(p) - 1, 0)
class Databricks(ThreadedDatabase):
    """Databricks database adapter, backed by the ``databricks-sql-connector``
    package, with one connection per worker thread."""

    dialect = Dialect()

    def __init__(self, *, thread_count, **kw):
        # The connector logs verbosely at INFO level; keep only warnings and above.
        logging.getLogger("databricks.sql").setLevel(logging.WARNING)
        self._args = kw
        # Default schema if the caller did not provide one.
        self.default_schema = kw.get('schema', 'hive_metastore')
        super().__init__(thread_count=thread_count)

    def create_connection(self):
        """Open a new connection to the Databricks SQL endpoint.

        Raises ConnectionError (chained to the connector's error) on failure.
        """
        databricks = import_databricks()
        try:
            return databricks.sql.connect(
                server_hostname=self._args['server_hostname'],
                http_path=self._args['http_path'],
                access_token=self._args['access_token'],
                catalog=self._args['catalog'],
            )
        except databricks.sql.exc.Error as e:
            raise ConnectionError(*e.args) from e

    def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
        """Return the raw schema of ``path`` as ``{column_name: (name, type, decimal_digits, None, None)}``.

        Raises RuntimeError when the table does not exist or has no columns.
        """
        # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
        # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
        # So, to obtain information about schema, we should use another approach.
        conn = self.create_connection()
        schema, table = self._normalize_table_path(path)
        rows = None
        try:
            # FIX: the connection previously leaked if conn.cursor() or
            # cursor.columns() raised, since conn.close() only guarded fetchall().
            with conn.cursor() as cursor:
                cursor.columns(catalog_name=self._args['catalog'], schema_name=schema, table_name=table)
                try:
                    rows = cursor.fetchall()
                except Exception:
                    # FIX: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit. Best-effort: treat a fetch
                    # failure as "no columns" and raise the RuntimeError below.
                    rows = None
        finally:
            conn.close()
        if not rows:
            raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

        d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
        assert len(d) == len(rows)
        return d

    def _process_table_schema(
        self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
    ):
        """Refine the raw schema rows into parsed ColTypes, keeping only ``filter_columns``."""
        accept = {i.lower() for i in filter_columns}
        rows = [row for name, row in raw_schema.items() if name.lower() in accept]

        resulted_rows = []
        for row in rows:
            # "DECIMAL(p,s)" is reported with its parameters; normalize to the bare key.
            row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1]
            type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType)

            if issubclass(type_cls, Integer):
                row = (row[0], row_type, None, None, 0)
            elif issubclass(type_cls, Float):
                # NOTE(review): resolved via `self`, but _convert_db_precision_to_digits
                # is defined on Dialect here — presumably ThreadedDatabase provides an
                # equivalent; confirm, or use self.dialect._convert_db_precision_to_digits.
                numeric_precision = self._convert_db_precision_to_digits(row[2])
                row = (row[0], row_type, None, numeric_precision, None)
            elif issubclass(type_cls, Decimal):
                # Parse precision/scale out of "DECIMAL(p,s)" — strip the 8-char
                # "DECIMAL(" prefix and the trailing ")".
                items = row[1][8:].rstrip(")").split(",")
                numeric_precision, numeric_scale = int(items[0]), int(items[1])
                row = (row[0], row_type, None, numeric_precision, numeric_scale)
            elif issubclass(type_cls, Timestamp):
                # DECIMAL_DIGITS holds the fractional-second precision for timestamps.
                row = (row[0], row_type, row[2], None, None)
            else:
                row = (row[0], row_type, None, None, None)
            resulted_rows.append(row)

        col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows}
        self._refine_coltypes(path, col_dict, where)
        return col_dict

    def parse_table_name(self, name: str) -> DbPath:
        """Parse a user-supplied table name into a normalized (schema, table) path."""
        path = parse_table_name(name)
        return self._normalize_table_path(path)

    @property
    def is_autocommit(self) -> bool:
        # Databricks SQL has no explicit transaction control here; every statement commits.
        return True