This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy path oracle.py
More file actions
130 lines (104 loc) · 4.49 KB
/
oracle.py
File metadata and controls
130 lines (104 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import re
from .database_types import *
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
from .base import DEFAULT_DATETIME_PRECISION, TIMESTAMP_PRECISION_POS
# Time zone applied to every new Oracle session: when truthy,
# Oracle.create_connection runs "ALTER SESSION SET TIME_ZONE = '<value>'"
# on each fresh connection. None (the default) leaves the session time zone
# untouched.
SESSION_TIME_ZONE = None # Changed by the tests
@import_helper("oracle")
def import_oracle():
    """Import the ``cx_Oracle`` driver lazily and hand back the module object.

    The import is deferred to call time so this module can be loaded even
    when the Oracle driver is not installed. The ``import_helper("oracle")``
    decorator is defined elsewhere; presumably it reports a missing driver
    in terms of the "oracle" extra (NOTE(review): confirm in ``.base``).
    """
    import cx_Oracle as driver

    return driver
class Oracle(ThreadedDatabase):
    """Oracle backend: connection management plus Oracle-specific SQL fragments.

    Connections are opened through the ``cx_Oracle`` module returned by
    ``import_oracle``. The remaining methods build SQL snippets used to read
    table schemas, normalize values to comparable strings, hash them, and
    limit result sets.
    """

    # Oracle ``data_type`` names (as reported by ALL_TAB_COLUMNS) mapped to
    # data-diff column types. TIMESTAMP(n) variants are parsed in _parse_type.
    TYPE_CLASSES: Dict[str, type] = {
        # Numbers
        "NUMBER": Decimal,
        "FLOAT": Float,
        # Text
        "CHAR": Text,
        "NCHAR": Text,
        "NVARCHAR2": Text,
        "VARCHAR2": Text,
    }

    # Passed as ``rounds`` to parsed timestamp types; True declares that this
    # backend rounds (rather than truncates) on precision loss, which selects
    # the cast-based branch of normalize_timestamp.
    ROUNDS_ON_PREC_LOSS = True

    # Timestamp type strings such as "TIMESTAMP(6) WITH TIME ZONE", tried in
    # order. Each pattern must match the whole type string; group 1 is the
    # fractional-second precision. Compiled once at class creation.
    _TIMESTAMP_TYPE_PATTERNS = (
        (re.compile(r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE"), Timestamp),
        (re.compile(r"TIMESTAMP\((\d)\) WITH TIME ZONE"), TimestampTZ),
        (re.compile(r"TIMESTAMP\((\d)\)"), Timestamp),
    )

    def __init__(self, *, host, database, thread_count, **kw):
        """Store connection arguments and start the worker threads.

        :param host: Oracle host (optionally with port), used as the DSN prefix.
        :param database: Service/database name appended as ``host/database``;
            when falsy, ``host`` alone is used as the DSN.
        :param thread_count: Forwarded to ``ThreadedDatabase``.
        :param kw: Extra keyword arguments passed verbatim to
            ``cx_Oracle.connect`` (e.g. ``user``, ``password``).
        """
        dsn = "%s/%s" % (host, database) if database else host
        self.kwargs = dict(dsn=dsn, **kw)
        # The connecting user (if supplied) doubles as the default schema.
        self.default_schema = kw.get("user")
        super().__init__(thread_count=thread_count)

    def create_connection(self):
        """Open a new ``cx_Oracle`` connection and apply session settings.

        :returns: The open connection.
        :raises ConnectError: If connecting, or applying SESSION_TIME_ZONE,
            fails. A connection that was opened is closed before raising, so
            it does not leak.
        """
        self._oracle = import_oracle()
        try:
            conn = self._oracle.connect(**self.kwargs)
        except Exception as e:
            raise ConnectError(*e.args) from e

        if SESSION_TIME_ZONE:
            try:
                cur = conn.cursor()
                try:
                    cur.execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'")
                finally:
                    # The cursor is only needed for this one statement.
                    cur.close()
            except Exception as e:
                conn.close()
                raise ConnectError(*e.args) from e
        return conn

    def _query(self, sql_code: str):
        """Run ``sql_code`` via the base class, translating driver errors.

        :raises QueryError: Wrapping any ``cx_Oracle.DatabaseError``; the
            original error is kept as ``__cause__``.
        """
        try:
            return super()._query(sql_code)
        except self._oracle.DatabaseError as e:
            raise QueryError(e) from e

    def md5_to_int(self, s: str) -> str:
        """Return SQL converting the MD5 of expression ``s`` to an integer.

        The hex digest is cut from character 18 onward (the last 15 hex
        digits of the 32-character digest), then parsed with a 15-digit hex
        format mask.
        """
        # standard_hash is faster than DBMS_CRYPTO.Hash
        # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
        return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')"

    def quote(self, s: str):
        """Return identifier ``s`` unchanged (identifiers are not quoted)."""
        return f"{s}"

    def to_string(self, s: str):
        """Return SQL casting expression ``s`` to ``varchar(1024)``."""
        return f"cast({s} as varchar(1024))"

    def select_table_schema(self, path: DbPath) -> str:
        """Return SQL listing column metadata for ``path`` from ALL_TAB_COLUMNS.

        Schema and table names are upper-cased before comparison, and single
        quotes are doubled so each name is always a valid SQL string literal.
        ``datetime_precision`` is a fixed placeholder of 6; the real precision
        of TIMESTAMP(n) columns is read from ``data_type`` in _parse_type.
        """
        schema, table = self._normalize_table_path(path)
        table_name = table.upper().replace("'", "''")
        owner = schema.upper().replace("'", "''")
        return (
            "SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale"
            f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table_name}' AND owner = '{owner}'"
        )

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL rendering timestamp ``value`` as a fixed-width string.

        When ``coltype.rounds`` is set, the value is cast to
        ``timestamp(precision)`` and always printed with 6 fractional digits.
        Otherwise it is printed with ``precision`` fractional digits (none
        when precision is 0). The result is then right-padded with '0' to
        ``TIMESTAMP_PRECISION_POS + 6`` characters.
        """
        if coltype.rounds:
            return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"

        if coltype.precision > 0:
            truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')"
        else:
            truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')"
        return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL formatting number ``value`` with ``to_char``.

        The format mask is "FM" followed by ``38 - precision`` nines. When
        precision > 0, it continues with "0." then ``precision - 1`` nines and
        a final "0". For example, precision=3 yields "FM" + 35 nines + "0.990".
        """
        format_str = "FM" + "9" * (38 - coltype.precision)
        if coltype.precision:
            format_str += "0." + "9" * (coltype.precision - 1) + "0"
        return f"to_char({value}, '{format_str}')"

    def _parse_type(
        self,
        table_name: DbPath,
        col_name: str,
        type_repr: str,
        datetime_precision: Optional[int] = None,
        numeric_precision: Optional[int] = None,
        numeric_scale: Optional[int] = None,
    ) -> ColType:
        """Map an Oracle type string to a data-diff column type.

        TIMESTAMP(n) variants are handled here, taking the precision ``n``
        from the type string itself. Everything else is delegated to the
        base class.
        """
        for pattern, t_cls in self._TIMESTAMP_TYPE_PATTERNS:
            m = pattern.fullmatch(type_repr)
            if m:
                # The precision in the type name is authoritative; the
                # datetime_precision column from select_table_schema is
                # only a placeholder.
                return t_cls(precision=int(m.group(1)), rounds=self.ROUNDS_ON_PREC_LOSS)

        return super()._parse_type(
            table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
        )

    def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
        """Return the row-limiting clause for a query.

        :param offset: Unsupported; any truthy value raises.
        :param limit: Maximum number of rows; ``None`` means no limit, in
            which case an empty string is returned instead of invalid SQL.
        :raises NotImplementedError: If ``offset`` is truthy.
        """
        if offset:
            raise NotImplementedError("No support for OFFSET in query")
        if limit is None:
            return ""
        return f"FETCH NEXT {limit} ROWS ONLY"

    def concat(self, l: List[str]) -> str:
        """Return SQL concatenating the expressions in ``l`` with ``||``."""
        joined_exprs = " || ".join(l)
        return f"({joined_exprs})"

    def timestamp_value(self, t: DbTime) -> str:
        """Return an Oracle ``timestamp '...'`` literal for ``t``.

        The literal uses ``isoformat`` with a space separator.
        """
        return "timestamp '%s'" % t.isoformat(" ")

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Return SQL trimming ``value`` and casting it to ``VARCHAR(36)``."""
        # Cast is necessary for correct MD5 (trimming not enough)
        return f"CAST(TRIM({value}) AS VARCHAR(36))"