import hashlib
import json
from abc import ABC
from collections.abc import Iterable, Iterator
from types import TracebackType
from typing import Any
import orjson
from cognite.client import CogniteClient
from cognite.client.data_classes import Row
from cognite.client.exceptions import CogniteAPIError
from cognite.extractorutils._inner_util import _DecimalDecoder, _DecimalEncoder
from cognite.extractorutils.threading import CancellationToken
from cognite.extractorutils.util import cognite_exceptions, retry
from ._base import RETRIES, RETRY_BACKOFF_FACTOR, RETRY_DELAY, RETRY_MAX_DELAY, _BaseStateStore
class AbstractHashStateStore(_BaseStateStore, ABC):
    """
    Base class for state stores that keep a content digest per external ID.

    The store does not keep the rows themselves. For every external ID it records the SHA-256 digest of the row
    that was last passed to ``set_state``, so ``has_changed`` can later tell whether a row differs from it.

    Args:
        save_interval: Forwarded unchanged to ``_BaseStateStore.__init__``.
        trigger_log_level: Forwarded unchanged to ``_BaseStateStore.__init__``.
        thread_name: Forwarded unchanged to ``_BaseStateStore.__init__``.
        cancellation_token: Forwarded unchanged to ``_BaseStateStore.__init__``.
    """

    def __init__(
        self,
        save_interval: int | None = None,
        trigger_log_level: str = "DEBUG",
        thread_name: str | None = None,
        cancellation_token: CancellationToken | None = None,
    ) -> None:
        super().__init__(
            save_interval=save_interval,
            trigger_log_level=trigger_log_level,
            thread_name=thread_name,
            cancellation_token=cancellation_token,
        )

        # Maps external ID -> {"digest": <hex digest>}.
        self._local_state: dict[str, dict[str, str]] = {}
        # Not read or written by any method of this class; kept so existing attribute access keeps working.
        self._seen: set[str] = set()

    def get_state(self, external_id: str) -> str | None:
        """
        Get the digest recorded for an external ID.

        Args:
            external_id: External ID to look up.

        Returns:
            The stored hex digest, or None if nothing is recorded for the external ID.
        """
        with self.lock:
            return self._local_state.get(external_id, {}).get("digest")

    def _hash_row(self, data: dict[str, Any]) -> str:
        """
        Compute the digest of a row.

        Args:
            data: Row content to hash. It must be serializable by ``orjson``.

        Returns:
            Hex-encoded SHA-256 digest of the serialized row.
        """
        # OPT_SORT_KEYS serializes dict keys in sorted order, so key insertion order does not change the digest.
        return hashlib.sha256(orjson.dumps(data, option=orjson.OPT_SORT_KEYS)).hexdigest()

    def set_state(self, external_id: str, data: dict[str, Any]) -> None:
        """
        Record the digest of a row, replacing any digest already stored for the external ID.

        Args:
            external_id: External ID the row belongs to.
            data: Row content. Only its digest is stored.
        """
        with self.lock:
            self._local_state[external_id] = {"digest": self._hash_row(data)}

    def has_changed(self, external_id: str, data: dict[str, Any]) -> bool:
        """
        Check whether a row differs from the one last recorded for its external ID.

        Args:
            external_id: External ID the row belongs to.
            data: Current row content.

        Returns:
            True if no digest is recorded for the external ID, or if the digest of ``data`` differs from the
            recorded one. False otherwise.
        """
        with self.lock:
            if external_id not in self._local_state:
                return True

            return self._hash_row(data) != self._local_state[external_id]["digest"]

    def __getitem__(self, external_id: str) -> str | None:
        """Return ``get_state(external_id)``. Unknown IDs give None rather than raising KeyError."""
        return self.get_state(external_id)

    def __setitem__(self, key: str, value: dict[str, Any]) -> None:
        """Record the digest of ``value`` under ``key`` by calling ``set_state``."""
        self.set_state(external_id=key, data=value)

    def __contains__(self, external_id: str) -> bool:
        """Return True if a digest is recorded for ``external_id``."""
        return external_id in self._local_state

    def __len__(self) -> int:
        """Return the number of external IDs with a recorded digest."""
        return len(self._local_state)

    def __iter__(self) -> Iterator[str]:
        """
        Iterate over the external IDs that have a recorded digest.

        The IDs are copied while holding the lock, and the lock is released before the first ID is handed out.
        Yielding from inside the ``with self.lock`` block would keep the lock held for as long as the caller's
        loop runs (and until the generator is finalized if the loop is abandoned), and any insertion made during
        the loop would raise ``RuntimeError`` because the dict changed size during iteration.

        Returns:
            Iterator over a snapshot of the stored external IDs.
        """
        with self.lock:
            external_ids = list(self._local_state)
        return iter(external_ids)
class RawHashStateStore(AbstractHashStateStore):
    """
    Hash state store that keeps its digests in a CDF RAW table.

    Each RAW row key is an external ID, and the digest is read from and written to a ``digest`` column.

    Args:
        cdf_client: Cognite client used for the RAW calls.
        database: Name of the RAW database holding the state table.
        table: Name of the RAW table holding the states.
        save_interval: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
        trigger_log_level: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
        thread_name: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
        cancellation_token: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
    """

    def __init__(
        self,
        cdf_client: CogniteClient,
        database: str,
        table: str,
        save_interval: int | None = None,
        trigger_log_level: str = "DEBUG",
        thread_name: str | None = None,
        cancellation_token: CancellationToken | None = None,
    ) -> None:
        super().__init__(
            save_interval=save_interval,
            trigger_log_level=trigger_log_level,
            thread_name=thread_name,
            cancellation_token=cancellation_token,
        )
        self._cdf_client = cdf_client
        self.database = database
        self.table = table

    def synchronize(self) -> None:
        """
        Upload local state store to CDF.

        The whole local state is inserted into the configured RAW table on every call. The upload is retried
        on the exceptions returned by ``cognite_exceptions()`` using the module-level retry settings.
        """

        @retry(
            exceptions=cognite_exceptions(),
            cancellation_token=self.cancellation_token,
            tries=RETRIES,
            delay=RETRY_DELAY,
            max_delay=RETRY_MAX_DELAY,
            backoff=RETRY_BACKOFF_FACTOR,
        )
        def impl() -> None:
            # The lock is taken per attempt, so it is not held while the retry helper waits between tries.
            with self.lock:
                self._cdf_client.raw.rows.insert(
                    db_name=self.database,
                    table_name=self.table,
                    row=self._local_state,
                    # NOTE(review): presumably makes the SDK create a missing database/table - confirm in SDK docs.
                    ensure_parent=True,
                )

        impl()

    def initialize(self, force: bool = False) -> None:
        """
        Get all known states.

        Replaces the local state with the digests found in the configured RAW table. A 404 response from the
        API is treated as an empty table. Rows without a key or columns are logged and skipped, and rows
        without a non-empty ``digest`` column are skipped silently.

        Args:
            force: Enable re-initialization, ie overwrite when called multiple times

        Raises:
            CogniteAPIError: If listing the rows fails with a code other than 404 (after any retries).
        """

        @retry(
            exceptions=cognite_exceptions(),
            cancellation_token=self.cancellation_token,
            tries=RETRIES,
            delay=RETRY_DELAY,
            max_delay=RETRY_MAX_DELAY,
            backoff=RETRY_BACKOFF_FACTOR,
        )
        def impl() -> None:
            if self._initialized and not force:
                return

            rows: Iterable[Row]
            try:
                rows = self._cdf_client.raw.rows.list(db_name=self.database, table_name=self.table, limit=None)
            except CogniteAPIError as e:
                if e.code == 404:
                    # Nothing has been stored yet: start from an empty state.
                    rows = []
                else:
                    # Bare raise re-raises the active exception without adding this frame to the traceback twice.
                    raise

            with self.lock:
                self._local_state.clear()
                for row in rows:
                    if row.key is None or row.columns is None:
                        self.logger.warning(f"None encountered in row: {str(row)}")
                        # should never happen, but type from sdk is optional
                        continue
                    state = row.columns.get("digest")
                    if state:
                        self._local_state[row.key] = {"digest": state}

            self._initialized = True

        impl()

    def __enter__(self) -> "RawHashStateStore":
        """
        Wraps around start method, for use as context manager

        Returns:
            self
        """
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Wraps around stop method, for use as context manager

        Returns None, so an exception raised inside the ``with`` block is not suppressed.

        Args:
            exc_type: Exception type
            exc_val: Exception value
            exc_tb: Traceback
        """
        self.stop()
class LocalHashStateStore(AbstractHashStateStore):
    """
    Hash state store that keeps its digests in a JSON file on the local file system.

    Args:
        file_path: Path of the JSON file the states are loaded from and saved to.
        save_interval: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
        trigger_log_level: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
        thread_name: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
        cancellation_token: Forwarded unchanged to ``AbstractHashStateStore.__init__``.
    """

    def __init__(
        self,
        file_path: str,
        save_interval: int | None = None,
        trigger_log_level: str = "DEBUG",
        thread_name: str | None = None,
        cancellation_token: CancellationToken | None = None,
    ) -> None:
        super().__init__(
            save_interval=save_interval,
            trigger_log_level=trigger_log_level,
            thread_name=thread_name,
            cancellation_token=cancellation_token,
        )
        self._file_path = file_path

    def initialize(self, force: bool = False) -> None:
        """
        Read the saved states from the JSON file.

        A missing file is not an error: the current local state is left as it is. Does nothing if the store
        is already initialized, unless ``force`` is set.

        Args:
            force: Load again even if the store has already been initialized.

        Raises:
            ValueError: If the file exists but does not contain valid JSON.
        """
        if not self._initialized or force:
            with self.lock:
                try:
                    with open(self._file_path) as state_file:
                        loaded_state = json.load(state_file, cls=_DecimalDecoder)
                except FileNotFoundError:
                    # No state has been saved yet, so there is nothing to load.
                    pass
                except json.JSONDecodeError as err:
                    raise ValueError(f"Invalid JSON in state store file: {err}") from err
                else:
                    self._local_state = loaded_state

            self._initialized = True

    def synchronize(self) -> None:
        """
        Write the current states to the JSON file, replacing its previous content.
        """
        with self.lock, open(self._file_path, "w") as state_file:
            json.dump(self._local_state, state_file, cls=_DecimalEncoder)

    def __enter__(self) -> "LocalHashStateStore":
        """
        Start the store when entering a ``with`` block.

        Returns:
            This store, after ``start`` has been called.
        """
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Stop the store when leaving a ``with`` block.

        Args:
            exc_type: Type of the exception raised in the block, if any. Not inspected.
            exc_val: The exception raised in the block, if any. Not inspected.
            exc_tb: Traceback of that exception, if any. Not inspected.
        """
        self.stop()