import dataclasses
import enum
import time
from typing import Any, Dict, List, Optional, Type, Tuple

import pydantic
import requests


class TestDataSourceStatus(str, enum.Enum):
    """Status codes used for a single data source test step."""

    SUCCESS = "ok"
    FAILED = "error"
    SKIP = "skip"
    UNKNOWN = "unknown"


class TCloudApiDataSourceSchema(pydantic.BaseModel):
    """Configuration schema of one data source type."""

    title: str
    properties: Dict[str, Dict[str, Any]]
    required: List[str]
    secret: List[str]


class TCloudApiDataSourceConfigSchema(pydantic.BaseModel):
    """A data source type (name + db type) together with its configuration schema."""

    name: str
    db_type: str
    config_schema: TCloudApiDataSourceSchema


class TCloudApiDataSource(pydantic.BaseModel):
    """A data source as returned by the data sources endpoints."""

    id: Optional[int] = None
    name: str
    type: str
    is_paused: Optional[bool] = False
    hidden: Optional[bool] = False
    temp_schema: Optional[str] = None
    disable_schema_indexing: Optional[bool] = False
    disable_profiling: Optional[bool] = False
    catalog_include_list: Optional[str] = None
    catalog_exclude_list: Optional[str] = None
    schema_indexing_schedule: Optional[str] = None
    schema_max_age_s: Optional[int] = None
    profile_schedule: Optional[str] = None
    profile_exclude_list: Optional[str] = None
    profile_include_list: Optional[str] = None
    discourage_manual_profiling: Optional[bool] = False
    lineage_schedule: Optional[str] = None
    float_tolerance: Optional[float] = 0.0
    options: Optional[Dict[str, Any]] = None
    queue_name: Optional[str] = None
    scheduled_queue_name: Optional[str] = None
    groups: Optional[Dict[int, bool]] = None
    view_only: Optional[bool] = False
    created_from: Optional[str] = None
    source: Optional[str] = None
    max_allowed_connections: Optional[int] = None
    last_test: Optional[Any] = None
    secret_id: Optional[int] = None


class TDsConfig(pydantic.BaseModel):
    """Payload sent by ``DatafoldAPI.create_data_source``."""

    name: str
    type: str
    temp_schema: str
    float_tolerance: float = 0.0
    options: Dict[str, Any]
    disable_schema_indexing: bool = True
    disable_profiling: bool = True


class TCloudApiDataDiff(pydantic.BaseModel):
    """Payload sent by ``DatafoldAPI.create_data_diff``."""

    data_source1_id: int
    data_source2_id: int
    table1: List[str]
    table2: List[str]
    pk_columns: List[str]


class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):
    """Primary-key statistics of a diff summary; every field is a pair of ints."""

    total_rows: Tuple[int, int]
    nulls: Tuple[int, int]
    dupes: Tuple[int, int]
    exclusives: Tuple[int, int]
    distincts: Tuple[int, int]


class TSummaryResultColumnDiffStats(pydantic.BaseModel):
    """Per-column entry of ``TSummaryResultValueStats.columns_diff_stats``."""

    column_name: str
    match: float


class TSummaryResultValueStats(pydantic.BaseModel):
    """Value-level statistics of a diff summary."""

    total_rows: int
    rows_with_differences: int
    total_values: int
    compared_columns: int
    columns_with_differences: int
    columns_diff_stats: List[TSummaryResultColumnDiffStats]


class TSummaryResultSchemaStats(pydantic.BaseModel):
    """Schema-level statistics of a diff summary."""

    columns_mismatched: Tuple[int, int]
    column_type_mismatches: int
    column_reorders: int
    column_counts: Tuple[int, int]


class TCloudApiDataDiffSummaryResult(pydantic.BaseModel):
    """Summary of a data diff, built from the ``summary_results`` response."""

    status: str
    # Explicit ``None`` defaults keep these fields optional regardless of how
    # the installed pydantic version treats ``Optional`` without a default.
    pks: Optional[TSummaryResultPrimaryKeyStats] = None
    values: Optional[TSummaryResultValueStats] = None
    schema_: Optional[TSummaryResultSchemaStats] = None
    dependencies: Optional[Dict[str, Any]] = None

    @classmethod
    def from_orm(cls, obj: Any) -> "TCloudApiDataDiffSummaryResult":
        """Build a summary result from a response mapping.

        Reads the keys ``status`` (required), ``pks``, ``values``, ``schema``
        and ``deps`` from ``obj``. A section that is absent or ``None`` is
        stored as ``None``.

        Raises:
            KeyError: if ``obj`` has no ``status`` key.
        """
        pks_raw = obj.get("pks")
        values_raw = obj.get("values")
        schema_raw = obj.get("schema")

        return cls(
            status=obj["status"],
            pks=TSummaryResultPrimaryKeyStats(**pks_raw) if pks_raw is not None else None,
            values=TSummaryResultValueStats(**values_raw) if values_raw is not None else None,
            schema_=TSummaryResultSchemaStats(**schema_raw) if schema_raw is not None else None,
            # The payload key is "deps" while the model field is "dependencies";
            # passing it under the field name is what actually stores the value.
            dependencies=obj.get("deps"),
        )


class TCloudDataSourceTestResult(pydantic.BaseModel):
    """Result details of one data source test step."""

    status: TestDataSourceStatus
    message: str
    outcome: str


class TCloudApiDataSourceTestResult(pydantic.BaseModel):
    """One step of a data source test job, with its nested result."""

    name: str
    status: str
    result: TCloudDataSourceTestResult


@dataclasses.dataclass
class DatafoldAPI:
    """Thin HTTP client for the Datafold API.

    Attributes:
        api_key: key sent in the ``Authorization`` header as ``Key <api_key>``.
        host: base URL; trailing slashes are stripped in ``__post_init__``.
        timeout: timeout passed to every ``requests`` call.
    """

    api_key: str
    host: str = "https://app.datafold.com"
    timeout: int = 30

    def __post_init__(self):
        """Normalize the host and prepare the headers used by every request."""
        self.host = self.host.rstrip("/")
        self.headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }

    def make_get_request(self, url: str) -> requests.Response:
        """GET ``<host>/<url>`` and return the response.

        Raises:
            requests.HTTPError: via ``raise_for_status`` on an error status.
        """
        rv = requests.get(
            url=f"{self.host}/{url}",
            headers=self.headers,
            timeout=self.timeout,
        )
        rv.raise_for_status()
        return rv

    def make_post_request(self, url: str, payload: Any) -> requests.Response:
        """POST ``payload`` as JSON to ``<host>/<url>`` and return the response.

        Raises:
            requests.HTTPError: via ``raise_for_status`` on an error status.
        """
        rv = requests.post(
            url=f"{self.host}/{url}",
            headers=self.headers,
            json=payload,
            timeout=self.timeout,
        )
        rv.raise_for_status()
        return rv

    def get_data_sources(self) -> List[TCloudApiDataSource]:
        """Return the data sources listed by ``api/data_sources``."""
        # make_get_request already raises on an error status.
        rv = self.make_get_request(url="api/data_sources")
        return [TCloudApiDataSource(**item) for item in rv.json()]

    def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
        """Create a data source from ``config`` and return the parsed response."""
        # TODO: replace an internal url by a public one
        rv = self.make_post_request(url="api/internal/data_sources", payload=config.dict())
        return TCloudApiDataSource(**rv.json())

    def get_data_source_schema_config(self) -> List[TCloudApiDataSourceConfigSchema]:
        """Return the data source types with their configuration schemas."""
        # TODO: replace an internal url by a public one
        rv = self.make_get_request(url="api/internal/data_sources/types")
        return [
            TCloudApiDataSourceConfigSchema(
                name=item["name"],
                db_type=item["type"],
                config_schema=TCloudApiDataSourceSchema(
                    title=item["configuration_schema"]["title"],
                    properties=item["configuration_schema"]["properties"],
                    required=item["configuration_schema"]["required"],
                    secret=item["configuration_schema"]["secret"],
                ),
            )
            for item in rv.json()
        ]

    def create_data_diff(self, payload: TCloudApiDataDiff) -> int:
        """Submit a data diff and return the ``id`` field of the response."""
        rv = self.make_post_request(url="api/v1/datadiffs", payload=payload.dict())
        return rv.json()["id"]

    def poll_data_diff_results(self, diff_id: int) -> TCloudApiDataDiffSummaryResult:
        """Poll the summary results of a diff until it succeeds or fails.

        The delay between polls starts at 5 seconds and doubles after every
        attempt, capped at 60 seconds.

        Raises:
            Exception: if the reported status is ``"failed"``, or if more than
                300 seconds have elapsed without a final status.
        """
        start_time = time.time()
        sleep_interval = 5  # starts at 5 sec
        max_sleep_interval = 60
        max_wait_time = 300
        url = f"api/v1/datadiffs/{diff_id}/summary_results"

        while True:
            response_json = self.make_get_request(url=url).json()
            status = response_json["status"]

            if status == "success":
                # Return right away: a finished diff must not be turned into a
                # timeout error, and there is nothing left to wait for.
                return TCloudApiDataDiffSummaryResult.from_orm(response_json)
            if status == "failed":
                raise Exception(f"Diff failed: {str(response_json)}")

            if time.time() - start_time > max_wait_time:
                raise Exception("Timed out waiting for diff results")

            time.sleep(sleep_interval)
            sleep_interval = min(sleep_interval * 2, max_sleep_interval)

    def test_data_source(self, data_source_id: int) -> int:
        """Start a test of the given data source and return its ``job_id``."""
        # TODO: replace an internal url by a public one
        rv = self.make_post_request(f"api/internal/data_sources/{data_source_id}/test", {})
        return rv.json()["job_id"]

    def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourceTestResult]:
        """Return the per-step results of a data source test job.

        Each entry of the response's ``results`` list is mapped to a
        ``TCloudApiDataSourceTestResult``; the nested result ``code`` is
        lower-cased before being parsed as a ``TestDataSourceStatus``.
        """
        # TODO: replace an internal url by a public one
        rv = self.make_get_request(f"api/internal/data_sources/test/{job_id}")
        return [
            TCloudApiDataSourceTestResult(
                name=item["step"],
                status=item["status"],
                result=TCloudDataSourceTestResult(
                    status=item["result"]["code"].lower(),
                    message=item["result"]["message"],
                    outcome=item["result"]["outcome"],
                ),
            )
            for item in rv.json()["results"]
        ]