This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy pathsnowflake.py
More file actions
182 lines (147 loc) · 6.1 KB
/
snowflake.py
File metadata and controls
182 lines (147 loc) · 6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from typing import Any, Union, List
import logging
import attrs
from data_diff.abcs.database_types import (
Timestamp,
TimestampTZ,
Decimal,
Float,
Text,
FractionalType,
TemporalType,
DbPath,
Boolean,
Date,
)
from data_diff.abcs.mixins import (
AbstractMixin_MD5,
AbstractMixin_NormalizeValue,
)
from data_diff.abcs.compiler import Compilable
from data_diff.queries.api import table, this, SKIP, code
from data_diff.databases.base import (
BaseDialect,
ConnectError,
Database,
import_helper,
CHECKSUM_MASK,
ThreadLocalInterpreter,
CHECKSUM_OFFSET,
)
@import_helper("snowflake")
def import_snowflake():
    """Import the Snowflake connector and the private-key helpers on demand.

    The imports are done inside the function body, so the third-party
    packages are only needed once this function is actually called.

    Returns:
        A ``(snowflake, serialization, default_backend)`` tuple: the top-level
        ``snowflake`` package (with ``snowflake.connector`` loaded), the
        ``cryptography`` serialization module, and the ``default_backend``
        callable.
    """
    import snowflake.connector
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.backends import default_backend

    loaded = (snowflake, serialization, default_backend)
    return loaded
class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
    """Snowflake SQL dialect: type-name mapping and SQL fragment builders."""

    name = "Snowflake"
    ROUNDS_ON_PREC_LOSS = False
    # Maps Snowflake type names to the column-type classes they are parsed into.
    TYPE_CLASSES = {
        # Timestamps
        "TIMESTAMP_NTZ": Timestamp,
        "TIMESTAMP_LTZ": Timestamp,
        "TIMESTAMP_TZ": TimestampTZ,
        "DATE": Date,
        # Numbers
        "NUMBER": Decimal,
        "FLOAT": Float,
        # Text
        "TEXT": Text,
        # Boolean
        "BOOLEAN": Boolean,
    }

    def explain_as_text(self, query: str) -> str:
        """Prefix ``query`` with ``EXPLAIN USING TEXT``."""
        explained = f"EXPLAIN USING TEXT {query}"
        return explained

    def quote(self, s: str):
        """Wrap ``s`` in double quotes (no escaping is applied)."""
        quoted = f'"{s}"'
        return quoted

    def to_string(self, s: str):
        """Return SQL that casts the expression ``s`` to ``string``."""
        casted = f"cast({s} as string)"
        return casted

    def set_timezone_to_utc(self) -> str:
        """Return the statement that switches the session time zone to UTC."""
        return "ALTER SESSION SET TIMEZONE = 'UTC'"

    def optimizer_hints(self, hints: str) -> str:
        """Not supported for Snowflake; always raises ``NotImplementedError``."""
        raise NotImplementedError("Optimizer hints not yet implemented in snowflake")

    def type_repr(self, t) -> str:
        """Return the SQL type name for ``t``.

        ``TimestampTZ`` is rendered here as ``timestamp_tz(<precision>)``;
        every other type is delegated to the parent implementation.
        """
        if not isinstance(t, TimestampTZ):
            return super().type_repr(t)
        return f"timestamp_tz({t.precision})"

    def md5_as_int(self, s: str) -> str:
        """Return SQL that reduces the MD5 of ``s`` to a bounded integer.

        The expression applies ``md5_number_lower64``, masks the result with
        ``CHECKSUM_MASK`` and subtracts ``CHECKSUM_OFFSET``.
        """
        lower64 = f"md5_number_lower64({s})"
        return f"BITAND({lower64}, {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL that renders ``value`` as a UTC timestamp string.

        When ``coltype.rounds`` is true the value is rounded to
        ``coltype.precision`` via its epoch-nanosecond representation;
        otherwise it is cast to ``timestamp(coltype.precision)``. In both
        cases the result is formatted with six fractional digits.
        """
        utc_value = f"convert_timezone('UTC', {value})"
        if coltype.rounds:
            epoch_seconds = f"date_part(epoch_nanosecond, {utc_value}::timestamp(9))/1000000000"
            normalized = f"to_timestamp(round({epoch_seconds}, {coltype.precision}))"
        else:
            normalized = f"cast({utc_value} as timestamp({coltype.precision}))"
        return f"to_char({normalized}, 'YYYY-MM-DD HH24:MI:SS.FF6')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL that casts ``value`` to ``decimal(38, precision)`` and then to a string."""
        as_decimal = f"cast({value} as decimal(38, {coltype.precision}))"
        return self.to_string(as_decimal)

    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
        """Return SQL that casts ``value`` to ``int`` and then to a string."""
        as_int = f"{value}::int"
        return self.to_string(as_int)
@attrs.define(frozen=False, init=False, kw_only=True)
class Snowflake(Database):
    """Database driver that issues its queries over one Snowflake connection."""

    dialect = Dialect()
    CONNECT_URI_HELP = "snowflake://<user>:<password>@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>"
    CONNECT_URI_PARAMS = ["database", "schema"]
    CONNECT_URI_KWPARAMS = ["warehouse"]

    # Connection object returned by ``snowflake.connector.connect``.
    _conn: Any

    def __init__(self, *, schema: str, **kw) -> None:
        """Open the Snowflake connection.

        Args:
            schema: Default schema name. It is passed to the connector wrapped
                in double quotes and must not itself contain a double quote.
            **kw: Extra keyword arguments forwarded to
                ``snowflake.connector.connect``. If ``key`` is present it is
                read as the path of a PEM private key (optionally decrypted
                with ``private_key_passphrase``) and the DER-encoded key is
                added as ``private_key``.

        Raises:
            AssertionError: If ``schema`` contains a double quote.
            ConnectError: If both ``password`` and ``key`` are supplied.
        """
        super().__init__()

        snowflake, serialization, default_backend = import_snowflake()
        logging.getLogger("snowflake.connector").setLevel(logging.WARNING)

        # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state
        # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145
        logging.getLogger("snowflake.connector.network").disabled = True

        # Raised explicitly (instead of using ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if '"' in schema:
            raise AssertionError("Schema name should not contain quotes!")

        # If a private key is used, read it from the specified path and pass it as "private_key" to the connector.
        if "key" in kw:
            # Reject the conflicting configuration before touching the file
            # system, so a bad key path cannot mask this error.
            if "password" in kw:
                raise ConnectError("Cannot use password and key at the same time")

            passphrase = kw.get("private_key_passphrase")
            encoded_passphrase = passphrase.encode() if passphrase else None

            with open(kw["key"], "rb") as key_file:
                pem_bytes = key_file.read()

            p_key = serialization.load_pem_private_key(
                pem_bytes,
                password=encoded_passphrase,
                backend=default_backend(),
            )
            kw["private_key"] = p_key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )

        self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw)
        self.default_schema = schema

    def close(self):
        """Run the parent ``close()`` and then close the Snowflake connection."""
        super().close()
        self._conn.close()

    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
        "Uses the standard SQL cursor interface"
        return self._query_conn(self._conn, sql_code)

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
        database, schema, name = self._normalize_table_path(path)

        info_schema_path = ["information_schema", "columns"]
        if database:
            info_schema_path.insert(0, database)

        # NOTE(review): ``name`` and ``schema`` are interpolated into the SQL
        # literals as-is; callers are expected to pass trusted identifiers.
        return (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand ``path`` to a ``(database, schema, table)`` triple.

        A one-part path uses ``self.default_schema``; one- and two-part paths
        get ``None`` for the database; a three-part path is returned as-is.

        Raises:
            ValueError: If ``path`` does not have one, two or three parts.
        """
        if len(path) == 1:
            return None, self.default_schema, path[0]
        elif len(path) == 2:
            return None, path[0], path[1]
        elif len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
        )

    @property
    def is_autocommit(self) -> bool:
        """Always ``True`` for this driver."""
        return True

    def query_table_unique_columns(self, path: DbPath) -> List[str]:
        """Return an empty list; no unique columns are reported by this driver."""
        return []