# Copyright 2023 Cognite AS
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import re
import sys
from enum import Enum
from hashlib import sha256
from typing import Any, Callable, Dict, Generic, Iterable, Optional, TextIO, Type, TypeVar, Union
import dacite
import yaml
from yaml.scanner import ScannerError
from cognite.extractorutils.configtools._util import _to_snake_case
from cognite.extractorutils.configtools.elements import BaseConfig, ConfigType, TimeIntervalConfig, _BaseConfig
from cognite.extractorutils.exceptions import InvalidConfigError
# Logger for this module.
_logger = logging.getLogger(__name__)
# Type variable for the caller's config class. It is bound to BaseConfig so that the loaders below (and
# ConfigResolver) are typed as returning the caller's own subclass rather than plain BaseConfig.
CustomConfigClass = TypeVar("CustomConfigClass", bound=BaseConfig)
def _load_yaml(
    source: Union[TextIO, str],
    config_type: Type[CustomConfigClass],
    case_style: str = "hyphen",
    expand_envvars: bool = True,
    dict_manipulator: Callable[[Dict[str, Any]], Dict[str, Any]] = lambda x: x,
) -> CustomConfigClass:
    """
    Parse YAML and build a config object of the given type from it.

    Args:
        source: Input stream (as returned by open(...)) or string containing YAML.
        config_type: Dataclass to instantiate (typically a subclass of BaseConfig).
        case_style: Casing convention of the keys in the YAML document, passed on to ``_to_snake_case``. When
            'hyphen', field names in error messages are rendered with hyphens instead of underscores.
        expand_envvars: If True, scalars starting with a ``${VAR}`` reference are expanded with
            ``os.path.expandvars``, and an expanded value equal to "true"/"false" (case-insensitive) becomes a bool.
        dict_manipulator: Hook applied to the parsed dictionary before key conversion and type checking.

    Returns:
        An initialized config object, with ``_file_hash`` set to a SHA-256 hex digest of the converted config
        dictionary.

    Raises:
        InvalidConfigError: If the YAML cannot be scanned, the document is empty or not a mapping, or a config
            field is unknown, missing or of the wrong type.
        ValueError: If dacite cannot resolve a forward reference in ``config_type``.
    """

    def env_constructor(_: yaml.SafeLoader, node):
        bool_values = {
            "true": True,
            "false": False,
        }
        expanded_value = os.path.expandvars(node.value)
        return bool_values.get(expanded_value.lower(), expanded_value)

    # A fresh subclass per call, so the resolver/constructor registration does not leak into yaml.SafeLoader.
    class EnvLoader(yaml.SafeLoader):
        pass

    # Tag scalars that start with a ${VAR} reference as !env so env_constructor expands them.
    # NOTE(review): PyYAML presumably applies implicit resolvers to plain (unquoted) scalars only, so a quoted
    # "${VAR}" is likely left unexpanded — confirm before relying on it.
    EnvLoader.add_implicit_resolver("!env", re.compile(r"\$\{([^}^{]+)\}"), None)
    EnvLoader.add_constructor("!env", env_constructor)

    loader = EnvLoader if expand_envvars else yaml.SafeLoader

    # Safe to use load instead of safe_load since both loader classes are based on SafeLoader
    try:
        config_dict = yaml.load(source, Loader=loader)
    except ScannerError as e:
        location = e.problem_mark or e.context_mark
        formatted_location = f" at line {location.line+1}, column {location.column+1}" if location is not None else ""
        cause = e.problem or e.context
        raise InvalidConfigError(f"Invalid YAML{formatted_location}: {cause or ''}") from e

    # An empty document loads as None, and a top-level list/scalar is not a dict either. Reject these here with a
    # clear message instead of failing later with an unhelpful error in the manipulator, key conversion or dacite.
    if config_dict is None:
        raise InvalidConfigError("Invalid config: the YAML document is empty")
    if not isinstance(config_dict, dict):
        raise InvalidConfigError(
            f"Invalid config: expected a mapping at the top level, got {type(config_dict).__name__}"
        )

    config_dict = dict_manipulator(config_dict)
    config_dict = _to_snake_case(config_dict, case_style)

    try:
        config = dacite.from_dict(
            data=config_dict, data_class=config_type, config=dacite.Config(strict=True, cast=[Enum, TimeIntervalConfig])
        )
    except dacite.UnexpectedDataError as e:
        # Sorted so the message does not depend on the iteration order of the reported keys.
        unknowns = [f'"{k.replace("_", "-") if case_style == "hyphen" else k}"' for k in sorted(e.keys)]
        raise InvalidConfigError(f"Unknown config parameter{'s' if len(unknowns) > 1 else ''} {', '.join(unknowns)}")
    except (dacite.WrongTypeError, dacite.MissingValueError, dacite.UnionMatchError) as e:
        path = e.field_path.replace("_", "-") if case_style == "hyphen" else e.field_path

        def name(type_: Type) -> str:
            return type_.__name__ if hasattr(type_, "__name__") else str(type_)

        def all_types(type_: Type) -> Iterable[Type]:
            # Union/Optional types expose their members in __args__; plain types stand alone.
            return type_.__args__ if hasattr(type_, "__args__") else [type_]

        if isinstance(e, (dacite.WrongTypeError, dacite.UnionMatchError)) and e.value is not None:
            got_type = name(type(e.value))
            need_type = ", ".join(name(t) for t in all_types(e.field_type))
            raise InvalidConfigError(
                f'Wrong type for field "{path}" - got "{e.value}" of type {got_type} instead of {need_type}'
            )
        # Missing fields, and fields explicitly set to null, are both reported as missing.
        raise InvalidConfigError(f'Missing mandatory field "{path}"')
    except dacite.ForwardReferenceError as e:
        raise ValueError(f"Invalid config class: {str(e)}") from e

    # default=str keeps the hash computation from crashing on YAML values json cannot encode natively
    # (e.g. dates/timestamps accepted by Any-typed fields); the digest only needs to be stable for change detection.
    config._file_hash = sha256(json.dumps(config_dict, default=str).encode("utf-8")).hexdigest()
    return config
def load_yaml(
    source: Union[TextIO, str],
    config_type: Type[CustomConfigClass],
    case_style: str = "hyphen",
    expand_envvars: bool = True,
) -> CustomConfigClass:
    """
    Read a YAML file, and create a config object based on its contents.

    Args:
        source: Input stream (as returned by open(...)) or string containing YAML.
        config_type: Class of config type (i.e. your custom subclass of BaseConfig).
        case_style: Casing convention of config file. Valid options are 'snake', 'hyphen' or 'camel'. Defaults to
            (and should preferably be) 'hyphen'.
        expand_envvars: Substitute values with the pattern ${VAR} with the content of the environment variable VAR

    Returns:
        An initialized config object.

    Raises:
        InvalidConfigError: If the YAML is malformed (scanner errors), or if any config field is given as an invalid
            type, is missing or is unknown
        ValueError: If the config class itself is invalid (unresolvable forward references)
    """
    return _load_yaml(source=source, config_type=config_type, case_style=case_style, expand_envvars=expand_envvars)
class ConfigResolver(Generic[CustomConfigClass]):
    """
    Resolve a config file that is either local, or that points to a remote config stored on an extraction pipeline.

    The local file at ``config_path`` is re-read on every resolution. If it declares a remote config type, it is
    only used for CDF connection details: the actual config text is retrieved through the extraction pipeline and
    the local ``cognite`` settings are injected into it. Newly resolved configs are staged in ``_next_config`` and
    only become visible through :attr:`config` after :meth:`accept_new_config`.

    Args:
        config_path: Path to the local YAML config file.
        config_type: Class of config type (i.e. your custom subclass of BaseConfig).
    """

    def __init__(self, config_path: str, config_type: Type[CustomConfigClass]):
        self.config_path = config_path
        self.config_type = config_type

        # Currently accepted config, and the most recently resolved (staged) one.
        self._config: Optional[CustomConfigClass] = None
        self._next_config: Optional[CustomConfigClass] = None

    def _reload_file(self) -> None:
        """Read the raw text of the local config file into ``_config_text``."""
        with open(self.config_path, "r") as stream:
            self._config_text = stream.read()

    @property
    def is_remote(self) -> bool:
        """
        Whether the local config file declares a remote config type.

        A missing ``type`` (or a document that is not a mapping at all, such as an empty file) is treated as local,
        with a warning. Requires ``_reload_file`` to have been called first.
        """
        raw_config = yaml.safe_load(self._config_text)
        # An empty file loads as None; don't crash on `.get` here, let the local loading path deal with it.
        raw_config_type = raw_config.get("type") if isinstance(raw_config, dict) else None
        if raw_config_type is None:
            _logger.warning("No config type specified, default to local")
            raw_config_type = "local"
        config_type = ConfigType(raw_config_type)
        return config_type == ConfigType.REMOTE

    @property
    def has_changed(self) -> bool:
        """
        Re-resolve the config and report whether it differs from the currently accepted one.

        The new config is only staged; call :meth:`accept_new_config` to start using it. Returns True if no config
        has been accepted yet, and False (after logging the exception) if resolution fails.
        """
        try:
            self._resolve_config()
        except Exception:
            # Deliberately best-effort: a failed reload is logged and reported as "no change".
            _logger.exception("Failed to reload configuration file")
            return False
        if self._config is None:
            # Nothing has been accepted yet, so there is no hash to compare against.
            return True
        return self._config._file_hash != self._next_config._file_hash

    @property
    def config(self) -> CustomConfigClass:
        """The currently accepted config, resolved and accepted on first access."""
        if self._config is None:
            self._resolve_config()
            self.accept_new_config()
        return self._config

    def accept_new_config(self) -> None:
        """Make the most recently resolved (staged) config the current one."""
        self._config = self._next_config

    @classmethod
    def from_cli(
        cls, name: str, description: str, version: str, config_type: Type[CustomConfigClass]
    ) -> "ConfigResolver":
        """
        Create a resolver from command line arguments.

        Parses ``sys.argv`` for a single positional config file path, and adds a ``-v``/``--version`` flag that
        prints ``"<name> v<version>"``.
        """
        argument_parser = argparse.ArgumentParser(sys.argv[0], description=description)
        argument_parser.add_argument(
            "config", nargs=1, type=str, help="The YAML file containing configuration for the extractor."
        )
        argument_parser.add_argument("-v", "--version", action="version", version=f"{name} v{version}")
        args = argument_parser.parse_args()

        return cls(args.config[0], config_type)

    def _inject_cognite(self, local_part: _BaseConfig, remote_part: Dict[str, Any]) -> Dict[str, Any]:
        """
        Copy the CDF connection settings from the local config into the remote config dictionary.

        Overwrites the remote ``idp-authentication`` and ``extraction-pipeline`` sections, copies ``project``, and
        copies ``host`` when it is set locally. ``remote_part`` is modified in place and also returned.
        """
        # Treat a missing section and an explicitly empty (null) `cognite:` section the same way.
        if remote_part.get("cognite") is None:
            remote_part["cognite"] = {}
        cognite_part = remote_part["cognite"]

        idp = local_part.cognite.idp_authentication
        cognite_part["idp-authentication"] = {
            "client_id": idp.client_id,
            "scopes": idp.scopes,
            "secret": idp.secret,
            "tenant": idp.tenant,
            "token_url": idp.token_url,
            "resource": idp.resource,
            "authority": idp.authority,
        }
        if local_part.cognite.host is not None:
            cognite_part["host"] = local_part.cognite.host
        cognite_part["project"] = local_part.cognite.project

        cognite_part["extraction-pipeline"] = {
            "id": local_part.cognite.extraction_pipeline.id,
            "external_id": local_part.cognite.extraction_pipeline.external_id,
        }

        return remote_part

    def _resolve_config(self) -> None:
        """
        Re-read the local file and resolve the config it describes into ``_next_config``.

        For a remote config, the local file is loaded as a ``_BaseConfig`` to connect to CDF, the config text is
        retrieved via the extraction pipeline, and the local ``cognite`` settings are injected into it before it is
        parsed as ``config_type``.

        Raises:
            InvalidConfigError: If a config is invalid, or a remote config resolves no extraction pipeline.
        """
        self._reload_file()

        if self.is_remote:
            _logger.debug("Loading remote config file")
            tmp_config: _BaseConfig = load_yaml(self._config_text, _BaseConfig)
            client = tmp_config.cognite.get_cognite_client("config_resolver")
            extraction_pipeline = tmp_config.cognite.get_extraction_pipeline(client)
            # NOTE(review): assumes get_extraction_pipeline returns None when no pipeline is configured; guard so the
            # user gets a config error instead of an AttributeError on `.external_id`.
            if extraction_pipeline is None:
                raise InvalidConfigError("Remote config requires an extraction pipeline to be configured")
            response = client.extraction_pipelines.config.retrieve(extraction_pipeline.external_id)

            self._next_config = _load_yaml(
                source=response.config,
                config_type=self.config_type,
                dict_manipulator=lambda d: self._inject_cognite(tmp_config, d),
            )
        else:
            _logger.debug("Loading local config file")
            self._next_config = load_yaml(self._config_text, self.config_type)