
"""Server API.

Provides access to server API.

"""
from __future__ import annotations

import os
import re
import io
import json
import time
import logging
import platform
import uuid
from contextlib import contextmanager
import typing
from typing import Optional, Iterable, Generator, Any, Union

import requests

from .constants import (
    SERVER_RETRIES_ENV_KEY,
    DEFAULT_FOLDER_TYPE_FIELDS,
    DEFAULT_TASK_TYPE_FIELDS,
    DEFAULT_PROJECT_STATUSES_FIELDS,
    DEFAULT_PROJECT_TAGS_FIELDS,
    DEFAULT_PRODUCT_TYPE_FIELDS,
    DEFAULT_PROJECT_FIELDS,
    DEFAULT_FOLDER_FIELDS,
    DEFAULT_TASK_FIELDS,
    DEFAULT_PRODUCT_FIELDS,
    DEFAULT_VERSION_FIELDS,
    DEFAULT_REPRESENTATION_FIELDS,
    REPRESENTATION_FILES_FIELDS,
    DEFAULT_WORKFILE_INFO_FIELDS,
    DEFAULT_EVENT_FIELDS,
    DEFAULT_ACTIVITY_FIELDS,
    DEFAULT_USER_FIELDS,
    DEFAULT_ENTITY_LIST_FIELDS,
)
from .graphql import INTROSPECTION_QUERY
from .graphql_queries import users_graphql_query
from .exceptions import (
    FailedOperations,
    UnauthorizedError,
    AuthenticationError,
    ServerNotReached,
)
from .utils import (
    RequestType,
    RequestTypes,
    RestApiResponse,
    prepare_query_string,
    logout_from_server,
    create_entity_id,
    entity_data_json_default,
    failed_json_default,
    TransferProgress,
    get_default_timeout,
    get_default_settings_variant,
    get_default_site_id,
    NOT_SET,
    get_media_mime_type,
    get_machine_name,
    fill_own_attribs,
)
from ._api_helpers import (
    InstallersAPI,
    DependencyPackagesAPI,
    SecretsAPI,
    BundlesAddonsAPI,
    EventsAPI,
    AttributesAPI,
    ProjectsAPI,
    FoldersAPI,
    TasksAPI,
    ProductsAPI,
    VersionsAPI,
    RepresentationsAPI,
    WorkfilesAPI,
    ThumbnailsAPI,
    ActivitiesAPI,
    ActionsAPI,
    LinksAPI,
    ListsAPI,
)

if typing.TYPE_CHECKING:
    from .typing import (
        ServerVersion,
        AnyEntityDict,
        StreamType,
        BackgroundOperationTask,
    )

VERSION_REGEX = re.compile(
    r"(?P<major>0|[1-9]\d*)"
    r"\.(?P<minor>0|[1-9]\d*)"
    r"\.(?P<patch>0|[1-9]\d*)"
    r"(?:-(?P<prerelease>[a-zA-Z\d\-.]*))?"
    r"(?:\+(?P<buildmetadata>[a-zA-Z\d\-.]*))?"
)


class GraphQlResponse:
    """GraphQl response."""

    def __init__(self, data):
        self.data = data
        self.errors = data.get("errors")

    def __len__(self):
        if self.errors:
            return 0
        return 1

    def __repr__(self):
        if self.errors:
            message = self.errors[0]["message"]
            return f"<{self.__class__.__name__} errors={message}>"
        return f"<{self.__class__.__name__}>"

class _AsUserStack:
    """Handle stack of users used over server api connection in service mode.

    ServerAPI can behave as other users if it is using special API key.

    Examples:
        >>> stack = _AsUserStack()
        >>> stack.set_default_username("DefaultName")
        >>> print(stack.username)
        DefaultName
        >>> with stack.as_user("Other1"):
        ...     print(stack.username)
        ...     with stack.as_user("Other2"):
        ...         print(stack.username)
        ...     print(stack.username)
        ...     stack.clear()
        ...     print(stack.username)
        Other1
        Other2
        Other1
        None
        >>> print(stack.username)
        None
        >>> stack.set_default_username("DefaultName")
        >>> print(stack.username)
        DefaultName

    """
    def __init__(self):
        self._users_by_id = {}
        self._user_ids = []
        self._last_user = None
        self._default_user = None

    def clear(self):
        self._users_by_id = {}
        self._user_ids = []
        self._last_user = None
        self._default_user = None

    @property
    def username(self) -> Optional[str]:
        # Use '_user_ids' for boolean check to have ability "unset"
        #   default user
        if self._user_ids:
            return self._last_user
        return self._default_user

    def get_default_username(self) -> Optional[str]:
        return self._default_user

    def set_default_username(self, username: Optional[str] = None) -> None:
        self._default_user = username

    default_username = property(get_default_username, set_default_username)

    @contextmanager
    def as_user(self, username: Optional[str]) -> Generator[None, None, None]:
        self._last_user = username
        user_id = uuid.uuid4().hex
        self._user_ids.append(user_id)
        self._users_by_id[user_id] = username
        try:
            yield
        finally:
            self._users_by_id.pop(user_id, None)
            if not self._user_ids:
                return

            # First check if the user id is the last one
            was_last = self._user_ids[-1] == user_id
            # Remove id from variables
            if user_id in self._user_ids:
                self._user_ids.remove(user_id)

            if not was_last:
                return

            new_last_user = None
            if self._user_ids:
                new_last_user = self._users_by_id.get(self._user_ids[-1])
            self._last_user = new_last_user

class ServerAPI(
    InstallersAPI,
    DependencyPackagesAPI,
    SecretsAPI,
    BundlesAddonsAPI,
    EventsAPI,
    AttributesAPI,
    ProjectsAPI,
    FoldersAPI,
    TasksAPI,
    ProductsAPI,
    VersionsAPI,
    RepresentationsAPI,
    WorkfilesAPI,
    ThumbnailsAPI,
    ActivitiesAPI,
    ActionsAPI,
    LinksAPI,
    ListsAPI,
):
    """Base handler of connection to server.

    Requires url to server which is used as base for api and graphql calls.

    Logging in creates a session that is reused for subsequent calls.

    Args:
        base_url (str): Example: http://localhost:5000
        token (Optional[str]): Access token (api key) to server.
        site_id (Optional[str]): Unique name of site. Should be the same
            when connection is created from the same machine under the
            same user.
        client_version (Optional[str]): Version of client application (used
            in desktop client application).
        default_settings_variant (Optional[Literal["production", "staging"]]):
            Settings variant used by default if a method for settings
            won't get any (by default is 'production').
        sender_type (Optional[str]): Sender type of requests. Used in
            server logs and propagated into events.
        sender (Optional[str]): Sender of requests, more specific than
            sender type (e.g. machine name). Used in server logs and
            propagated into events.
        ssl_verify (Optional[Union[bool, str]]): Verify SSL certificate.
            Looks for env variable value ``AYON_CA_FILE`` by default. If
            not available then 'True' is used.
        cert (Optional[str]): Path to certificate file. Looks for env
            variable value ``AYON_CERT_FILE`` by default.
        create_session (Optional[bool]): Create session for connection if
            token is available. Default is True.
        timeout (Optional[float]): Timeout for requests.
        max_retries (Optional[int]): Number of retries for requests.

    """
    _default_max_retries = 3
    # 1 MB chunk by default
    # TODO find out if these are reasonable default values
    default_download_chunk_size = 1024 * 1024
    default_upload_chunk_size = 1024 * 1024

    def __init__(
        self,
        base_url: str,
        token: Optional[str] = None,
        site_id: Optional[str] = NOT_SET,
        client_version: Optional[str] = None,
        default_settings_variant: Optional[str] = None,
        sender_type: Optional[str] = None,
        sender: Optional[str] = None,
        ssl_verify: Optional[Union[bool, str]] = None,
        cert: Optional[str] = None,
        create_session: bool = True,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
    ):
        if not base_url:
            raise ValueError(f"Invalid server URL {str(base_url)}")

        base_url = base_url.rstrip("/")
        self._base_url: str = base_url
        self._rest_url: str = f"{base_url}/api"
        self._graphql_url: str = f"{base_url}/graphql"
        self._log: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._access_token: Optional[str] = token

        # Allow 'site_id' to be 'None'
        if site_id is NOT_SET:
            site_id = get_default_site_id()
        self._site_id: Optional[str] = site_id
        self._client_version: Optional[str] = client_version
        self._default_settings_variant: str = (
            default_settings_variant
            or get_default_settings_variant()
        )
        self._sender: Optional[str] = sender
        self._sender_type: Optional[str] = sender_type

        self._timeout: float = 0.0
        self._max_retries: int = 0

        # Set timeout and max retries based on passed values
        self.set_timeout(timeout)
        self.set_max_retries(max_retries)

        if ssl_verify is None:
            # Custom AYON env variable for CA file or 'True'
            # - that should cover most default behaviors in 'requests'
            #   with 'certifi'
            ssl_verify = os.environ.get("AYON_CA_FILE") or True

        if cert is None:
            cert = os.environ.get("AYON_CERT_FILE")

        self._ssl_verify = ssl_verify
        self._cert = cert

        self._access_token_is_service = None
        self._token_is_valid = None
        self._token_validation_started = False

        self._server_available = None
        self._server_version = None
        self._server_version_tuple = None
        self._graphql_allows_traits_in_representations: Optional[bool] = None
        self._product_base_type_supported = None

        self._session = None

        self._base_functions_mapping = {
            RequestTypes.get: requests.get,
            RequestTypes.post: requests.post,
            RequestTypes.put: requests.put,
            RequestTypes.patch: requests.patch,
            RequestTypes.delete: requests.delete
        }
        self._session_functions_mapping = {}

        # Attributes cache
        self._attributes_schema = None
        self._entity_type_attributes_cache = {}

        self._as_user_stack = _AsUserStack()

        # Create session
        if self._access_token and create_session:
            self.validate_server_availability()
            self.create_session()

    @property
    def log(self) -> logging.Logger:
        return self._log

    def get_base_url(self):
        return self._base_url

    def get_rest_url(self):
        return self._rest_url

    base_url = property(get_base_url)
    rest_url = property(get_rest_url)

    def get_ssl_verify(self):
        """Enable ssl verification.

        Returns:
            bool: Current state of ssl verification.

        """
        return self._ssl_verify

    def set_ssl_verify(self, ssl_verify):
        """Change ssl verification state.

        Args:
            ssl_verify (Union[bool, str, None]): Enable/disable ssl
                verification, can be a path to file.

        """
        if self._ssl_verify == ssl_verify:
            return

        self._ssl_verify = ssl_verify
        if self._session is not None:
            self._session.verify = ssl_verify

    def get_cert(self):
        """Current cert file used for connection to server.

        Returns:
            Union[str, None]: Path to cert file.

        """
        return self._cert

    def set_cert(self, cert):
        """Change cert file used for connection to server.

        Args:
            cert (Union[str, None]): Path to cert file.

        """
        if cert == self._cert:
            return

        self._cert = cert
        if self._session is not None:
            self._session.cert = cert

    ssl_verify = property(get_ssl_verify, set_ssl_verify)
    cert = property(get_cert, set_cert)

    @classmethod
    def get_default_timeout(cls):
        """Default value for requests timeout.

        Utils function 'get_default_timeout' is used by default.

        Returns:
            float: Timeout value in seconds.

        """
        return get_default_timeout()

    @classmethod
    def get_default_max_retries(cls):
        """Default value for requests max retries.

        First looks for environment variable SERVER_RETRIES_ENV_KEY, which
        can affect max retries value. If not available then the class
        attribute '_default_max_retries' is used.

        Returns:
            int: Max retries value.

        """
        try:
            return int(os.environ.get(SERVER_RETRIES_ENV_KEY))
        except (ValueError, TypeError):
            pass

        return cls._default_max_retries

    def get_timeout(self) -> float:
        """Current value for requests timeout.

        Returns:
            float: Timeout value in seconds.

        """
        return self._timeout

    def set_timeout(self, timeout: Optional[float]):
        """Change timeout value for requests.

        Args:
            timeout (Optional[float]): Timeout value in seconds.

        """
        if timeout is None:
            timeout = self.get_default_timeout()
        self._timeout = float(timeout)

    def get_max_retries(self) -> int:
        """Current value for requests max retries.

        Returns:
            int: Max retries value.

        """
        return self._max_retries

    def set_max_retries(self, max_retries: Optional[int]):
        """Change max retries value for requests.

        Args:
            max_retries (Optional[int]): Max retries value.

        """
        if max_retries is None:
            max_retries = self.get_default_max_retries()
        self._max_retries = int(max_retries)

    timeout = property(get_timeout, set_timeout)
    max_retries = property(get_max_retries, set_max_retries)

    @property
    def access_token(self) -> Optional[str]:
        """Access token used for authorization to server.

        Returns:
            Optional[str]: Token string or None if not authorized yet.

        """
        return self._access_token

    def is_service_user(self) -> bool:
        """Check if connection is using service API key.

        Returns:
            bool: Used api key belongs to service user.

        """
        if not self.has_valid_token:
            raise ValueError("User is not logged in.")
        return bool(self._access_token_is_service)

    def get_site_id(self) -> Optional[str]:
        """Site id used for connection.

        Site id tells the server from which machine/site the connection is
        created and is used for default site overrides when settings are
        received.

        Returns:
            Optional[str]: Site id value or None if not filled.

        """
        return self._site_id

    def set_site_id(self, site_id: Optional[str]):
        """Change site id of connection.

        Behave as specific site for server. It affects default behavior of
        settings getter methods.

        Args:
            site_id (Optional[str]): Site id value, or 'None' to unset.

        """
        if self._site_id == site_id:
            return
        self._site_id = site_id
        # Update session headers on site id change
        self._update_session_headers()

    site_id = property(get_site_id, set_site_id)

    def get_client_version(self) -> Optional[str]:
        """Version of client used to connect to server.

        Client version is AYON client build desktop application.

        Returns:
            Optional[str]: Client version string used in connection.

        """
        return self._client_version

    def set_client_version(self, client_version: Optional[str]):
        """Set version of client used to connect to server.

        Client version is AYON client build desktop application.

        Args:
            client_version (Optional[str]): Client version string.

        """
        if self._client_version == client_version:
            return

        self._client_version = client_version
        self._update_session_headers()

    client_version = property(get_client_version, set_client_version)

    def get_default_settings_variant(self) -> str:
        """Default variant used for settings.

        Returns:
            str: Name of variant.

        """
        return self._default_settings_variant

    def set_default_settings_variant(self, variant: str):
        """Change default variant for addon settings.

        Note:
            It is recommended to set only 'production' or 'staging'
                variants as default variant.

        Args:
            variant (str): Settings variant name. It is possible to use
                'production', 'staging' or name of dev bundle.

        """
        self._default_settings_variant = variant

    default_settings_variant = property(
        get_default_settings_variant,
        set_default_settings_variant
    )

    def get_sender(self) -> Optional[str]:
        """Sender used to send requests.

        Returns:
            Union[str, None]: Sender name or None.

        """
        return self._sender

    def set_sender(self, sender: Optional[str]):
        """Change sender used for requests.

        Args:
            sender (Optional[str]): Sender name or None.

        """
        if sender == self._sender:
            return
        self._sender = sender
        self._update_session_headers()

    sender = property(get_sender, set_sender)

    def get_sender_type(self) -> Optional[str]:
        """Sender type used to send requests.

        Sender type is supported since AYON server 1.5.5.

        Returns:
            Optional[str]: Sender type or None.

        """
        return self._sender_type

    def set_sender_type(self, sender_type: Optional[str]):
        """Change sender type used for requests.

        Args:
            sender_type (Optional[str]): Sender type or None.

        """
        if sender_type == self._sender_type:
            return
        self._sender_type = sender_type
        self._update_session_headers()

    sender_type = property(get_sender_type, set_sender_type)

    def get_default_service_username(self) -> Optional[str]:
        """Default username used for callbacks when used with service API key.

        Returns:
            Union[str, None]: Username if any was filled.

        """
        return self._as_user_stack.get_default_username()

    def set_default_service_username(self, username: Optional[str] = None):
        """Service API will work as other user.

        Service API keys can work as other user. It can be temporary using
        context manager 'as_user' or it is possible to set default username
        if 'as_user' context manager is not entered.

        Args:
            username (Optional[str]): Username to work as when service.

        Raises:
            ValueError: When connection is not yet authenticated or api key
                is not service token.

        """
        current_username = self._as_user_stack.get_default_username()
        if current_username == username:
            return

        if not self.has_valid_token:
            raise ValueError(
                "Authentication of connection did not happen yet."
            )

        if not self._access_token_is_service:
            raise ValueError(
                "Can't set service username. API key is not a service token."
            )

        self._as_user_stack.set_default_username(username)
        if self._as_user_stack.username == username:
            self._update_session_headers()

    @contextmanager
    def as_username(
        self,
        username: Optional[str],
        ignore_service_error: bool = False,
    ):
        """Service API will temporarily work as other user.

        This method can be used only if service API key is logged in.

        Args:
            username (Optional[str]): Username to work as when service.
            ignore_service_error (Optional[bool]): Ignore error when
                service API key is not used.

        Raises:
            ValueError: When connection is not yet authenticated or api key
                is not service token.

        """
        if not self.has_valid_token:
            raise ValueError(
                "Authentication of connection did not happen yet."
            )

        if not self._access_token_is_service:
            if ignore_service_error:
                yield None
                return
            raise ValueError(
                "Can't set service username. API key is not a service token."
            )

        try:
            with self._as_user_stack.as_user(username) as o:
                self._update_session_headers()
                yield o
        finally:
            self._update_session_headers()

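    # Illustrative sketch of the context manager above. The connection name
    # 'con' and the username are hypothetical and assume 'con' was created
    # with a service API key:
    #
    #     con = ServerAPI("http://localhost:5000", token="<service api key>")
    #     with con.as_username("artist"):
    #         user = con.get_user()  # requests are sent on behalf of "artist"
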
    @property
    def is_server_available(self) -> bool:
        if self._server_available is None:
            response = requests.get(
                self._base_url,
                cert=self._cert,
                verify=self._ssl_verify
            )
            self._server_available = response.status_code == 200
        return self._server_available

    @property
    def has_valid_token(self) -> bool:
        if self._access_token is None:
            return False

        if self._token_is_valid is None:
            self.validate_token()
        return self._token_is_valid

    def validate_server_availability(self):
        if not self.is_server_available:
            raise ServerNotReached(
                f"Server \"{self._base_url}\" can't be reached"
            )

    def validate_token(self) -> bool:
        try:
            self._token_validation_started = True
            # TODO add other possible validations
            # - existence of 'user' key in info
            # - validate that 'site_id' is in 'sites' in info
            self.get_info()
            self.get_user()
            self._token_is_valid = True

        except UnauthorizedError:
            self._token_is_valid = False

        finally:
            self._token_validation_started = False
        return self._token_is_valid

    def set_token(self, token: Optional[str]):
        self.reset_token()
        self._access_token = token
        self.get_user()

    def reset_token(self):
        self._access_token = None
        self._token_is_valid = None
        self.close_session()

    def create_session(
        self, ignore_existing: bool = True, force: bool = False
    ):
        """Create a connection session.

        Session helps to keep connection with server without need to
        reconnect on each call.

        Args:
            ignore_existing (bool): If session already exists,
                ignore creation.
            force (bool): If session already exists, close it and
                create new.

        """
        if force and self._session is not None:
            self.close_session()

        if self._session is not None:
            if ignore_existing:
                return
            raise ValueError("Session is already created.")

        self._as_user_stack.clear()
        # Validate token before session creation
        self.validate_token()

        session = requests.Session()
        session.cert = self._cert
        session.verify = self._ssl_verify
        session.headers.update(self.get_headers())

        self._session_functions_mapping = {
            RequestTypes.get: session.get,
            RequestTypes.post: session.post,
            RequestTypes.put: session.put,
            RequestTypes.patch: session.patch,
            RequestTypes.delete: session.delete
        }
        self._session = session

    def close_session(self):
        if self._session is None:
            return

        session = self._session
        self._session = None
        self._session_functions_mapping = {}
        session.close()

    def _update_session_headers(self):
        if self._session is None:
            return

        # Header keys that may change over time
        for key, value in (
            ("X-as-user", self._as_user_stack.username),
            ("x-ayon-version", self._client_version),
            ("x-ayon-site-id", self._site_id),
            ("x-sender-type", self._sender_type),
            ("x-sender", self._sender),
        ):
            if value is not None:
                self._session.headers[key] = value
            elif key in self._session.headers:
                self._session.headers.pop(key)

    def get_info(self) -> dict[str, Any]:
        """Get information about current used api key.

        By default, the 'info' contains only 'uptime' and 'version'. With a
        logged-in user it also contains information about the user and the
        machines on which the user was logged in.

        Todos:
            Use this method for validation of token instead of 'get_user'.

        Returns:
            dict[str, Any]: Information from server.

        """
        response = self.get("info")
        response.raise_for_status()
        return response.data

    def get_server_version(self) -> str:
        """Get server version.

        Version should match semantic version (https://semver.org/).

        Returns:
            str: Server version.

        """
        if self._server_version is None:
            self._server_version = self.get_info()["version"]
        return self._server_version

    def get_server_version_tuple(self) -> ServerVersion:
        """Get server version as tuple.

        Version should match semantic version (https://semver.org/).

        The tuple contains major, minor and patch parts as integers,
        followed by prerelease and build metadata strings (empty when not
        present).

        Returns:
            ServerVersion: Server version.

        """
        if self._server_version_tuple is None:
            re_match = VERSION_REGEX.fullmatch(
                self.get_server_version())
            self._server_version_tuple = (
                int(re_match.group("major")),
                int(re_match.group("minor")),
                int(re_match.group("patch")),
                re_match.group("prerelease") or "",
                re_match.group("buildmetadata") or "",
            )
        return self._server_version_tuple

    server_version = property(get_server_version)
    server_version_tuple: ServerVersion = property(
        get_server_version_tuple
    )

    @property
    def graphql_allows_traits_in_representations(self) -> bool:
        """Check server support for representation traits."""
        if self._graphql_allows_traits_in_representations is None:
            major, minor, patch, _, _ = self.server_version_tuple
            self._graphql_allows_traits_in_representations = (
                (major, minor, patch) >= (1, 7, 5)
            )
        return self._graphql_allows_traits_in_representations

    def is_product_base_type_supported(self) -> bool:
        """Product base types are available on server."""
        if self._product_base_type_supported is None:
            major, minor, patch, _, _ = self.server_version_tuple
            self._product_base_type_supported = (
                (major, minor, patch) >= (1, 13, 0)
            )
        return self._product_base_type_supported

    def _get_user_info(self) -> Optional[dict[str, Any]]:
        if self._access_token is None:
            return None

        if self._access_token_is_service is not None:
            response = self.get("users/me")
            if response.status == 200:
                return response.data
            return None

        self._access_token_is_service = False
        response = self.get("users/me")
        if response.status == 200:
            return response.data

        self._access_token_is_service = True
        response = self.get("users/me")
        if response.status == 200:
            return response.data

        self._access_token_is_service = None
        return None

    def get_users(
        self,
        project_name: Optional[str] = None,
        usernames: Optional[Iterable[str]] = None,
        emails: Optional[Iterable[str]] = None,
        fields: Optional[Iterable[str]] = None,
    ) -> Generator[dict[str, Any], None, None]:
        """Get Users.

        Only administrators and managers can fetch all users. For other
        users it is required to pass in 'project_name' filter.

        Args:
            project_name (Optional[str]): Project name.
            usernames (Optional[Iterable[str]]): Filter by usernames.
            emails (Optional[Iterable[str]]): Filter by emails.
            fields (Optional[Iterable[str]]): Fields to be queried
                for users.

        Returns:
            Generator[dict[str, Any]]: Queried users.

        """
        filters = {}
        if usernames is not None:
            usernames = set(usernames)
            if not usernames:
                return
            filters["userNames"] = list(usernames)

        if emails is not None:
            emails = set(emails)
            if not emails:
                return

            major, minor, patch, _, _ = self.server_version_tuple
            emails_filter_available = (major, minor, patch) > (1, 7, 3)
            if not emails_filter_available:
                server_version = self.get_server_version()
                raise ValueError(
                    "Filtering by emails is not supported by"
                    f" server version {server_version}."
                )

            filters["emails"] = list(emails)

        if project_name is not None:
            filters["projectName"] = project_name

        if not fields:
            fields = self.get_default_fields_for_type("user")

        query = users_graphql_query(set(fields))
        for attr, filter_value in filters.items():
            query.set_variable_value(attr, filter_value)

        attributes = self.get_attributes_for_type("user")
        for parsed_data in query.continuous_query(self):
            for user in parsed_data["users"]:
                access_groups = user.get("accessGroups")
                if isinstance(access_groups, str):
                    user["accessGroups"] = json.loads(access_groups)
                all_attrib = user.get("allAttrib")
                if isinstance(all_attrib, str):
                    user["allAttrib"] = json.loads(all_attrib)
                if "attrib" in user:
                    user["ownAttrib"] = user["attrib"].copy()
                    attrib = user["attrib"]
                    for key, value in tuple(attrib.items()):
                        if value is not None:
                            continue
                        attr_def = attributes.get(key)
                        if attr_def is not None:
                            attrib[key] = attr_def["default"]
                yield user

    def get_user_by_name(
        self,
        username: str,
        project_name: Optional[str] = None,
        fields: Optional[Iterable[str]] = None,
    ) -> Optional[dict[str, Any]]:
        """Get user by name using GraphQl.

        Only administrators and managers can fetch all users. For other
        users it is required to pass in 'project_name' filter.

        Args:
            username (str): Username.
            project_name (Optional[str]): Define scope of project.
            fields (Optional[Iterable[str]]): Fields to be queried
                for users.

        Returns:
            Union[dict[str, Any], None]: User info or None if user is not
                found.

        """
        if not username:
            return None

        for user in self.get_users(
            project_name=project_name,
            usernames={username},
            fields=fields,
        ):
            return user
        return None

    def get_user(
        self, username: Optional[str] = None
    ) -> Optional[dict[str, Any]]:
        """Get user info using REST endpoint.

        User contains only explicitly set attributes in 'attrib'.

        Args:
            username (Optional[str]): Username.

        Returns:
            Optional[dict[str, Any]]: User info or None if user is not
                found.

        """
        if username is None:
            user = self._get_user_info()
            if user is None:
                raise UnauthorizedError("User is not authorized.")
        else:
            response = self.get(f"users/{username}")
            response.raise_for_status()
            user = response.data

        # NOTE Server does return only filled attributes right now.
        # This would fill all missing attributes with 'None'.
        # for attr_name in self.get_attributes_for_type("user"):
        #     user["attrib"].setdefault(attr_name, None)

        fill_own_attribs(user)
        return user

    def get_headers(
        self, content_type: Optional[str] = None
    ) -> dict[str, str]:
        if content_type is None:
            content_type = "application/json"

        headers = {
            "Content-Type": content_type,
            "x-ayon-platform": platform.system().lower(),
            "x-ayon-hostname": get_machine_name(),
            "referer": self.get_base_url(),
        }
        if self._site_id is not None:
            headers["x-ayon-site-id"] = self._site_id

        if self._client_version is not None:
            headers["x-ayon-version"] = self._client_version

        if self._sender_type is not None:
            headers["x-sender-type"] = self._sender_type

        if self._sender is not None:
            headers["x-sender"] = self._sender

        if self._access_token:
            if self._access_token_is_service:
                headers["X-Api-Key"] = self._access_token
                username = self._as_user_stack.username
                if username:
                    headers["X-as-user"] = username
            else:
                headers["Authorization"] = f"Bearer {self._access_token}"
        return headers

    def login(
        self, username: str, password: str, create_session: bool = True
    ):
        """Login to server.

        Args:
            username (str): Username.
            password (str): Password.
            create_session (Optional[bool]): Create session after login.
                Default: True.

        Raises:
            AuthenticationError: Login failed.

        """
        if self.has_valid_token:
            try:
                user_info = self.get_user()
            except UnauthorizedError:
                user_info = {}

            current_username = user_info.get("name")
            if current_username == username:
                self.close_session()
                if create_session:
                    self.create_session()
                return

        self.reset_token()

        self.validate_server_availability()

        self._token_validation_started = True

        try:
            response = self.post(
                "auth/login",
                name=username,
                password=password
            )
            if response.status_code != 200:
                _detail = response.data.get("detail")
                details = ""
                if _detail:
                    details = f" {_detail}"

                raise AuthenticationError(f"Login failed{details}")

        finally:
            self._token_validation_started = False

        self._access_token = response["token"]

        if not self.has_valid_token:
            raise AuthenticationError("Invalid credentials")

        if create_session:
            self.create_session()

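    # Illustrative login sketch with hypothetical credentials, assuming a
    # locally reachable server. 'login' validates server availability,
    # stores the received token and by default creates a session:
    #
    #     con = ServerAPI("http://localhost:5000")
    #     con.login("admin", "password")
    #     print(con.get_user()["name"])
    #     con.logout()
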
    def logout(self, soft: bool = False):
        if self._access_token:
            if not soft:
                self._logout()
            self.reset_token()

    def _logout(self):
        logout_from_server(self._base_url, self._access_token)

    def _do_rest_request(self, function, url, **kwargs):
        kwargs.setdefault("timeout", self.timeout)
        max_retries = kwargs.get("max_retries", self.max_retries)
        if max_retries < 1:
            max_retries = 1

        if self._session is None:
            # Validate token if was not yet validated
            # - ignore validation if we're in middle of validation
            if (
                self._token_is_valid is None
                and not self._token_validation_started
            ):
                self.validate_token()

            if "headers" not in kwargs:
                kwargs["headers"] = self.get_headers()

            if isinstance(function, RequestType):
                function = self._base_functions_mapping[function]

        elif isinstance(function, RequestType):
            function = self._session_functions_mapping[function]

        response = None
        new_response = None
        for retry_idx in reversed(range(max_retries)):
            # Reset error response so a later successful retry is not
            #   discarded by the check below
            new_response = None
            try:
                response = function(url, **kwargs)
                break

            except ConnectionRefusedError:
                if retry_idx == 0:
                    self.log.warning(
                        "Connection error happened.", exc_info=True
                    )

                # Server may be restarting
                new_response = RestApiResponse(
                    None,
                    {
                        "detail": (
                            "Unable to connect the server. Connection refused"
                        )
                    }
                )

            except requests.exceptions.Timeout:
                # Connection timed out
                new_response = RestApiResponse(
                    None,
                    {"detail": "Connection timed out."}
                )

            except requests.exceptions.ConnectionError:
                # Log warning only on last attempt
                if retry_idx == 0:
                    self.log.warning(
                        "Connection error happened.", exc_info=True
                    )

                new_response = RestApiResponse(
                    None,
                    {
                        "detail": (
                            "Unable to connect the server. Connection error"
                        )
                    }
                )

            time.sleep(0.1)

        if new_response is not None:
            return new_response

        new_response = RestApiResponse(response)
        self.log.debug(f"Response {str(new_response)}")
        return new_response

    def raw_post(self, entrypoint: str, **kwargs):
        url = self._endpoint_to_url(entrypoint)
        self.log.debug(f"Executing [POST] {url}")
        return self._do_rest_request(
            RequestTypes.post, url, **kwargs
        )

    def raw_put(self, entrypoint: str, **kwargs):
        url = self._endpoint_to_url(entrypoint)
        self.log.debug(f"Executing [PUT] {url}")
        return self._do_rest_request(
            RequestTypes.put, url, **kwargs
        )

    def raw_patch(self, entrypoint: str, **kwargs):
        url = self._endpoint_to_url(entrypoint)
        self.log.debug(f"Executing [PATCH] {url}")
        return self._do_rest_request(
            RequestTypes.patch, url, **kwargs
        )

    def raw_get(self, entrypoint: str, **kwargs):
        url = self._endpoint_to_url(entrypoint)
        self.log.debug(f"Executing [GET] {url}")
        return self._do_rest_request(
            RequestTypes.get, url, **kwargs
        )

    def raw_delete(self, entrypoint: str, **kwargs):
        url = self._endpoint_to_url(entrypoint)
        self.log.debug(f"Executing [DELETE] {url}")
        return self._do_rest_request(
            RequestTypes.delete, url, **kwargs
        )

    def post(self, entrypoint: str, **kwargs):
        return self.raw_post(entrypoint, json=kwargs)

    def put(self, entrypoint: str, **kwargs):
        return self.raw_put(entrypoint, json=kwargs)

    def patch(self, entrypoint: str, **kwargs):
        return self.raw_patch(entrypoint, json=kwargs)

    def get(self, entrypoint: str, **kwargs):
        return self.raw_get(entrypoint, params=kwargs)

    def delete(self, entrypoint: str, **kwargs):
        return self.raw_delete(entrypoint, params=kwargs)

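    # The wrappers above differ only in how keyword arguments are sent:
    # 'post', 'put' and 'patch' serialize them as JSON body, while 'get'
    # and 'delete' pass them as query parameters. Illustrative sketch using
    # the 'info' endpoint that this class itself uses in 'get_info':
    #
    #     response = con.get("info")
    #     response.raise_for_status()
    #     print(response.data["version"])
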
    def _endpoint_to_url(
        self,
        endpoint: str,
        use_rest: Optional[bool] = True
    ) -> str:
        """Cleanup endpoint and return full url to AYON server.

        If endpoint already starts with server url only slashes are removed.

        Args:
            endpoint (str): Endpoint to be cleaned.
            use_rest (Optional[bool]): Use the REST endpoint when True,
                otherwise the GraphQl endpoint is used.

        Returns:
            str: Full url to AYON server.

        """
        endpoint = endpoint.lstrip("/").rstrip("/")
        if endpoint.startswith(self._base_url):
            return endpoint
        base_url = self._rest_url if use_rest else self._graphql_url
        return f"{base_url}/{endpoint}"

    def _download_file_to_stream(
        self, url: str, stream, chunk_size, progress
    ):
        kwargs = {"stream": True}
        if self._session is None:
            kwargs["headers"] = self.get_headers()
            get_func = self._base_functions_mapping[RequestTypes.get]
        else:
            get_func = self._session_functions_mapping[RequestTypes.get]

        with get_func(url, **kwargs) as response:
            response.raise_for_status()
            progress.set_content_size(response.headers["Content-length"])
            for chunk in response.iter_content(chunk_size=chunk_size):
                stream.write(chunk)
                progress.add_transferred_chunk(len(chunk))

    def download_file_to_stream(
        self,
        endpoint: str,
        stream: StreamType,
        chunk_size: Optional[int] = None,
        progress: Optional[TransferProgress] = None,
    ) -> TransferProgress:
        """Download file from AYON server to IOStream.

        Endpoint can be full url (must start with 'base_url' of api object).

        Progress object can be used to track download. Useful when the
        download happens in a thread and other threads want to catch
        changes over time.

        Todos:
            Use retries and timeout.
            Return RestApiResponse.

        Args:
            endpoint (str): Endpoint or URL to file that should be
                downloaded.
            stream (StreamType): Stream where output will be stored.
            chunk_size (Optional[int]): Size of chunks that are received
                in single loop.
            progress (Optional[TransferProgress]): Object that gives
                ability to track download progress.

        """
        if not chunk_size:
            chunk_size = self.default_download_chunk_size

        url = self._endpoint_to_url(endpoint)

        if progress is None:
            progress = TransferProgress()

        progress.set_source_url(url)
        progress.set_started()
        try:
            self._download_file_to_stream(
                url, stream, chunk_size, progress
            )

        except Exception as exc:
            progress.set_failed(str(exc))
            raise

        finally:
            progress.set_transfer_done()
        return progress

    def download_file(
        self,
        endpoint: str,
        filepath: str,
        chunk_size: Optional[int] = None,
        progress: Optional[TransferProgress] = None,
    ) -> TransferProgress:
        """Download file from AYON server.

        Endpoint can be full url (must start with 'base_url' of api object).

        Progress object can be used to track download. Useful when the
        download happens in a thread and other threads want to catch
        changes over time.

        Todos:
            Use retries and timeout.
            Return RestApiResponse.

        Args:
            endpoint (str): Endpoint or URL to file that should be
                downloaded.
            filepath (str): Path where file will be downloaded.
            chunk_size (Optional[int]): Size of chunks that are received
                in single loop.
            progress (Optional[TransferProgress]): Object that gives
                ability to track download progress.

        """
        # Create dummy object so the function does not have to check
        #   'progress' variable everywhere
        if progress is None:
            progress = TransferProgress()

        progress.set_destination_url(filepath)

        dst_directory = os.path.dirname(filepath)
        os.makedirs(dst_directory, exist_ok=True)

        try:
            with open(filepath, "wb") as stream:
                self.download_file_to_stream(
                    endpoint, stream, chunk_size, progress
                )

        except Exception as exc:
            progress.set_failed(str(exc))
            raise

        return progress

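    # Illustrative download sketch. The endpoint path and target filepath
    # are hypothetical; any endpoint (or full url starting with the server
    # base url) that returns file content can be used:
    #
    #     progress = TransferProgress()
    #     con.download_file(
    #         "projects/my_project/files/<file id>",
    #         "/tmp/downloaded_file.bin",
    #         progress=progress,
    #     )
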
    @staticmethod
    def _upload_chunks_iter(
        file_stream: StreamType,
        progress: TransferProgress,
        chunk_size: int,
    ) -> Generator[bytes, None, None]:
        """Generator that yields chunks of file.

        Args:
            file_stream (StreamType): Byte stream.
            progress (TransferProgress): Object to track upload progress.
            chunk_size (int): Size of chunks that are uploaded at once.

        Yields:
            bytes: Chunk of file.

        """
        # Get size of file
        file_stream.seek(0, io.SEEK_END)
        size = file_stream.tell()
        file_stream.seek(0)
        # Set content size to progress object
        progress.set_content_size(size)

        while True:
            chunk = file_stream.read(chunk_size)
            if not chunk:
                break
            progress.add_transferred_chunk(len(chunk))
            yield chunk

    def _upload_file(
        self,
        url: str,
        stream: StreamType,
        progress: TransferProgress,
        request_type: Optional[RequestType] = None,
        chunk_size: Optional[int] = None,
        **kwargs
    ) -> requests.Response:
        """Upload file to server.

        Args:
            url (str): Url where file will be uploaded.
            stream (StreamType): File stream.
            progress (TransferProgress): Object that gives ability to track
                progress.
            request_type (Optional[RequestType]): Type of request that will
                be used. Default is PUT.
            chunk_size (Optional[int]): Size of chunks that are uploaded
                at once.
            **kwargs (Any): Additional arguments that will be passed
                to request function.

        Returns:
            requests.Response: Server response.

        """
        if request_type is None:
            request_type = RequestTypes.put

        if self._session is None:
            headers = kwargs.setdefault("headers", {})
            for key, value in self.get_headers().items():
                if key not in headers:
                    headers[key] = value
            post_func = self._base_functions_mapping[request_type]
        else:
            post_func = self._session_functions_mapping[request_type]

        if not chunk_size:
            chunk_size = self.default_upload_chunk_size

        response = post_func(
            url,
            data=self._upload_chunks_iter(stream, progress, chunk_size),
            **kwargs
        )

        response.raise_for_status()
        return response

    def upload_file_from_stream(
        self,
        endpoint: str,
        stream: StreamType,
        progress: Optional[TransferProgress] = None,
        request_type: Optional[RequestType] = None,
        **kwargs
    ) -> requests.Response:
        """Upload file to server from bytes.

        Todos:
            Use retries and timeout.
            Return RestApiResponse.

        Args:
            endpoint (str): Endpoint or url where file will be uploaded.
            stream (StreamType): File content stream.
            progress (Optional[TransferProgress]): Object that gives
                ability to track upload progress.
            request_type (Optional[RequestType]): Type of request that will
                be used to upload file.
            **kwargs (Any): Additional arguments that will be passed
                to request function.

        Returns:
            requests.Response: Response object.

        """
        url = self._endpoint_to_url(endpoint)

        # Create dummy object so the function does not have to check
        #   'progress' variable everywhere
        if progress is None:
            progress = TransferProgress()

        progress.set_destination_url(url)
        progress.set_started()
        try:
            return self._upload_file(
                url, stream, progress, request_type, **kwargs
            )

        except Exception as exc:
            progress.set_failed(str(exc))
            raise

        finally:
            progress.set_transfer_done()

    def upload_file(
        self,
        endpoint: str,
        filepath: str,
        progress: Optional[TransferProgress] = None,
        request_type: Optional[RequestType] = None,
        **kwargs
    ) -> requests.Response:
        """Upload file to server.

        Todos:
            Use retries and timeout.
            Return RestApiResponse.

        Args:
            endpoint (str): Endpoint or url where file will be uploaded.
            filepath (str): Source filepath.
            progress (Optional[TransferProgress]): Object that gives
                ability to track upload progress.
            request_type (Optional[RequestType]): Type of request that will
                be used to upload file.
            **kwargs (Any): Additional arguments that will be passed
                to request function.

        Returns:
            requests.Response: Response object.

        """
        if progress is None:
            progress = TransferProgress()

        progress.set_source_url(filepath)

        with open(filepath, "rb") as stream:
            return self.upload_file_from_stream(
                endpoint, stream, progress, request_type, **kwargs
            )

    def upload_reviewable(
        self,
        project_name: str,
        version_id: str,
        filepath: str,
        label: Optional[str] = None,
        content_type: Optional[str] = None,
        filename: Optional[str] = None,
        progress: Optional[TransferProgress] = None,
        headers: Optional[dict[str, Any]] = None,
        **kwargs
    ) -> requests.Response:
        """Upload reviewable file to server.

        Args:
            project_name (str): Project name.
            version_id (str): Version id.
            filepath (str): Reviewable file path to upload.
            label (Optional[str]): Reviewable label. Filled automatically
                server side with filename.
            content_type (Optional[str]): MIME type of the file.
            filename (Optional[str]): Used as original filename. Filename
                from 'filepath' is used when not filled.
            progress (Optional[TransferProgress]): Progress.
            headers (Optional[dict[str, Any]]): Headers.

        Returns:
            requests.Response: Server response.

        """
        if not content_type:
            content_type = get_media_mime_type(filepath)

        if not content_type:
            raise ValueError(
                f"Could not determine MIME type of file '{filepath}'"
            )

        if headers is None:
            headers = self.get_headers(content_type)
        else:
            # Make sure content-type is filled with file content type
            content_type_key = next(
                (
                    key
                    for key in headers
                    if key.lower() == "content-type"
                ),
                "Content-Type"
            )
            headers[content_type_key] = content_type

        # Fill original filename if not explicitly defined
        if not filename:
            filename = os.path.basename(filepath)

        headers["x-file-name"] = filename
        query = prepare_query_string({"label": label or None})
        endpoint = (
            f"/projects/{project_name}"
            f"/versions/{version_id}/reviewables{query}"
        )
        return self.upload_file(
            endpoint,
            filepath,
            progress=progress,
            headers=headers,
            request_type=RequestTypes.post,
            **kwargs
        )

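    # Illustrative reviewable upload sketch with hypothetical project name,
    # version id and file path. MIME type is guessed from the file when
    # 'content_type' is not passed:
    #
    #     con.upload_reviewable(
    #         "my_project",
    #         "<version id>",
    #         "/path/to/review.mp4",
    #         label="Review",
    #     )
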
    def trigger_server_restart(self):
        """Trigger server restart.

        Restart may be required when a change of specific value happened on
        server.

        """
        result = self.post("system/restart")
        if result.status_code != 204:
            # TODO add better exception
            raise ValueError("Failed to restart server")

    def query_graphql(
        self,
        query: str,
        variables: Optional[dict[str, Any]] = None,
    ) -> GraphQlResponse:
        """Execute GraphQl query.

        Args:
            query (str): GraphQl query string.
            variables (Optional[dict[str, Any]]): Variables that can be
                used in query.

        Returns:
            GraphQlResponse: Response from server.

        """
        data = {"query": query, "variables": variables or {}}
        response = self._do_rest_request(
            RequestTypes.post,
            self._graphql_url,
            json=data
        )
        response.raise_for_status()
        return GraphQlResponse(response)

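    # Illustrative GraphQl sketch; the query and project name are
    # hypothetical. 'errors' is filled when the query fails, otherwise the
    # payload is available under 'data' (as used by 'get_graphql_schema'):
    #
    #     query = "query Project($name: String!) { project(name: $name) { code } }"
    #     result = con.query_graphql(query, {"name": "my_project"})
    #     if not result.errors:
    #         print(result.data["data"]["project"]["code"])
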
    def get_graphql_schema(self) -> dict[str, Any]:
        return self.query_graphql(INTROSPECTION_QUERY).data["data"]

    def get_server_schema(self) -> Optional[dict[str, Any]]:
        """Get server schema with info, url paths, components etc.

        Todos:
            Cache schema - How to find out it is outdated?

        Returns:
            dict[str, Any]: Full server schema.

        """
        url = f"{self._base_url}/openapi.json"
        response = self._do_rest_request(RequestTypes.get, url)
        if response:
            return response.data
        return None

    def get_schemas(self) -> dict[str, Any]:
        """Get components schema.

        Name of components does not match entity type names e.g. 'project'
        is under 'ProjectModel'. We should find out some mapping. Also,
        there are properties which don't have information about reference
        to object e.g. 'config' has just object definition without
        reference schema.

        Returns:
            dict[str, Any]: Component schemas.

        """
        server_schema = self.get_server_schema()
        return server_schema["components"]["schemas"]

    def get_default_fields_for_type(self, entity_type: str) -> set[str]:
        """Default fields for entity type.

        Returns most of commonly used fields from server.

        Args:
            entity_type (str): Name of entity type.

        Returns:
            set[str]: Fields that should be queried from server.

        """
        # Event does not have attributes
        if entity_type == "event":
            return set(DEFAULT_EVENT_FIELDS)

        if entity_type == "activity":
            return set(DEFAULT_ACTIVITY_FIELDS)

        if entity_type == "project":
            entity_type_defaults = set(DEFAULT_PROJECT_FIELDS)
            maj_v, min_v, patch_v, _, _ = self.server_version_tuple
            if (maj_v, min_v, patch_v) > (1, 10, 0):
                entity_type_defaults.add("productTypes")

        elif entity_type == "folder":
            entity_type_defaults = set(DEFAULT_FOLDER_FIELDS)

        elif entity_type == "task":
            entity_type_defaults = set(DEFAULT_TASK_FIELDS)

        elif entity_type == "product":
            entity_type_defaults = set(DEFAULT_PRODUCT_FIELDS)
            maj_v, min_v, patch_v, _, _ = self.server_version_tuple
            if self.is_product_base_type_supported():
                entity_type_defaults.add("productBaseType")

        elif entity_type == "version":
            entity_type_defaults = set(DEFAULT_VERSION_FIELDS)

        elif entity_type == "representation":
            entity_type_defaults = (
                DEFAULT_REPRESENTATION_FIELDS
                | REPRESENTATION_FILES_FIELDS
            )
            if not self.graphql_allows_traits_in_representations:
                entity_type_defaults.discard("traits")

        elif entity_type == "productType":
            entity_type_defaults = set(DEFAULT_PRODUCT_TYPE_FIELDS)

        elif entity_type == "workfile":
            entity_type_defaults = set(DEFAULT_WORKFILE_INFO_FIELDS)

        elif entity_type == "user":
            entity_type_defaults = set(DEFAULT_USER_FIELDS)

        elif entity_type == "entityList":
            entity_type_defaults = set(DEFAULT_ENTITY_LIST_FIELDS)

        else:
            raise ValueError(f"Unknown entity type \"{entity_type}\"")

        return (
            entity_type_defaults
            | self.get_attributes_fields_for_type(entity_type)
        )

    def get_rest_entity_by_id(
        self,
        project_name: str,
        entity_type: str,
        entity_id: str,
    ) -> Optional[AnyEntityDict]:
        """Get entity using REST on a project by its id.

        Args:
            project_name (str): Name of project where entity is.
            entity_type (Literal["folder", "task", "product", "version"]):
                The entity type which should be received.
            entity_id (str): Id of entity.

        Returns:
            Optional[AnyEntityDict]: Received entity data.

        """
        if not all((project_name, entity_type, entity_id)):
            return None

        response = self.get(
            f"projects/{project_name}/{entity_type}s/{entity_id}"
        )
        if response.status == 200:
            return response.data
        return None

# --- Batch operations processing ---
    def send_batch_operations(
        self,
        project_name: str,
        operations: list[dict[str, Any]],
        can_fail: bool = False,
        raise_on_fail: bool = True,
    ) -> list[dict[str, Any]]:
        """Post multiple CRUD operations to server.

        When multiple changes should be made on server side this is the
        best way to go. It is possible to pass multiple operations to
        process on a server side and do the changes in a transaction.

        Args:
            project_name (str): On which project should be operations
                processed.
            operations (list[dict[str, Any]]): Operations to be processed.
            can_fail (Optional[bool]): Server will try to process all
                operations even if one of them fails.
            raise_on_fail (Optional[bool]): Raise exception if an operation
                fails. You can handle failed operations on your own when
                set to 'False'.

        Raises:
            ValueError: Operations can't be converted to json string.
            FailedOperations: When output does not contain server
                operations or 'raise_on_fail' is enabled and any operation
                fails.

        Returns:
            list[dict[str, Any]]: Operations result with process details.

        """
        return self._send_batch_operations(
            f"projects/{project_name}/operations",
            operations,
            can_fail,
            raise_on_fail,
        )

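    # Illustrative operations payload sketch with hypothetical entity
    # values. Each operation carries an operation type, the entity type and
    # entity specific payload (the exact keys shown here are an assumption,
    # not verified against the server API):
    #
    #     operations = [
    #         {
    #             "type": "update",
    #             "entityType": "task",
    #             "entityId": "<task id>",
    #             "data": {"status": "In progress"},
    #         },
    #     ]
    #     con.send_batch_operations("my_project", operations)
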
    def send_background_batch_operations(
        self,
        project_name: str,
        operations: list[dict[str, Any]],
        *,
        can_fail: bool = False,
        wait: bool = False,
        raise_on_fail: bool = True,
    ) -> BackgroundOperationTask:
        """Post multiple CRUD operations to server.

        When multiple changes should be made on server side this is the
        best way to go. It is possible to pass multiple operations to
        process on a server side and do the changes in a transaction.

        Compared to 'send_batch_operations' this function creates a task on
        the server which can then be periodically checked for a status and
        its result. When used with 'wait' set to 'True' this method blocks
        until the task is finished, which makes it behave like
        'send_batch_operations' but safer for large operation batches as it
        is not bound to response timeout.

        Args:
            project_name (str): On which project should be operations
                processed.
            operations (list[dict[str, Any]]): Operations to be processed.
            can_fail (Optional[bool]): Server will try to process all
                operations even if one of them fails.
            wait (bool): Wait for operations to end.
            raise_on_fail (Optional[bool]): Raise exception if an operation
                fails. You can handle failed operations on your own when
                set to 'False'. Used when 'wait' is enabled.

        Raises:
            ValueError: Operations can't be converted to json string.
            FailedOperations: When output does not contain server
                operations or 'raise_on_fail' is enabled and any operation
                fails.

        Returns:
            BackgroundOperationTask: Background operation.

        """
        operations_body = self._prepare_operations_body(operations)
        response = self.post(
            f"projects/{project_name}/operations/background",
            operations=operations_body,
            canFail=can_fail
        )
        response.raise_for_status()
        if not wait:
            return response.data

        task_id = response["id"]
        time.sleep(0.1)
        while True:
            op_status = self.get_background_operations_status(
                project_name, task_id
            )
            if op_status["status"] == "completed":
                break
            time.sleep(1)

        if raise_on_fail:
            self._validate_operations_result(
                op_status["result"], operations_body
            )
        return op_status

    def get_background_operations_status(
        self, project_name: str, task_id: str
    ) -> BackgroundOperationTask:
        """Get status of background operations task.

        Args:
            project_name (str): Project name.
            task_id (str): Background operation task id.

        Returns:
            BackgroundOperationTask: Background operation.

        """
        response = self.get(
            f"projects/{project_name}/operations/background/{task_id}"
        )
        response.raise_for_status()
        return response.data

    def _prepare_operations_body(
        self, operations: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        operations_body = []
        for operation in operations:
            if not operation:
                continue
            op_id = operation.get("id")
            if not op_id:
                op_id = create_entity_id()
                operation["id"] = op_id

            try:
                body = json.loads(
                    json.dumps(operation, default=entity_data_json_default)
                )
            except (TypeError, ValueError):
                raise ValueError("Couldn't json parse body: {}".format(
                    json.dumps(
                        operation, indent=4, default=failed_json_default
                    )
                ))

            operations_body.append(body)
        return operations_body

    def _send_batch_operations(
        self,
        uri: str,
        operations: list[dict[str, Any]],
        can_fail: bool,
        raise_on_fail: bool
    ) -> list[dict[str, Any]]:
        if not operations:
            return []

        operations_body = self._prepare_operations_body(operations)
        if not operations_body:
            return []

        response = self.post(
            uri,
            operations=operations_body,
            canFail=can_fail
        )
        op_results = response.get("operations")
        if op_results is None:
            detail = response.get("detail")
            if detail:
                raise FailedOperations(f"Operation failed. Detail: {detail}")
            raise FailedOperations(
                f"Operation failed. Content: {response.text}"
            )

        if raise_on_fail:
            self._validate_operations_result(response.data, operations_body)

        return op_results

    def _validate_operations_result(
        self,
        result: dict[str, Any],
        operations_body: list[dict[str, Any]],
    ) -> None:
        if result.get("success"):
            return None

        for op_result in result["operations"]:
            if op_result["success"]:
                continue
            operation_id = op_result["id"]
            operation = next(
                op
                for op in operations_body
                if op["id"] == operation_id
            )
            detail = op_result["detail"]
            raise FailedOperations(
                f"Operation \"{operation_id}\" failed with data:"
                f"\n{json.dumps(operation, indent=4)}"
                f"\nDetail: {detail}."
            )

    def _prepare_fields(
        self,
        entity_type: str,
        fields: set[str],
        own_attributes: bool = False
    ):
        if not fields:
            return

        if "attrib" in fields:
            fields.remove("attrib")
            fields |= self.get_attributes_fields_for_type(entity_type)

        if own_attributes and entity_type in {"project", "folder", "task"}:
            fields.add("ownAttrib")

        if entity_type != "project":
            return

        # Use 'data' to fill 'bundle' data
        if "bundle" in fields:
            fields.remove("bundle")
            fields.add("data")

        maj_v, min_v, patch_v, _, _ = self.server_version_tuple
        if "folderTypes" in fields:
            fields.remove("folderTypes")
            folder_types_fields = set(DEFAULT_FOLDER_TYPE_FIELDS)
            if (maj_v, min_v, patch_v) > (1, 10, 0):
                folder_types_fields |= {"shortName"}
            fields |= {
                f"folderTypes.{name}"
                for name in folder_types_fields
            }

        if "taskTypes" in fields:
            fields.remove("taskTypes")
            task_types_fields = set(DEFAULT_TASK_TYPE_FIELDS)
            if (maj_v, min_v, patch_v) > (1, 10, 0):
                task_types_fields |= {"color", "icon", "shortName"}
            fields |= {
                f"taskTypes.{name}"
                for name in task_types_fields
            }

        for field, default_fields in (
            ("statuses", DEFAULT_PROJECT_STATUSES_FIELDS),
            ("tags", DEFAULT_PROJECT_TAGS_FIELDS),
            ("linkTypes", DEFAULT_PROJECT_TAGS_FIELDS),
        ):
            if (maj_v, min_v, patch_v) <= (1, 10, 0):
                break
            if field in fields:
                fields.remove(field)
                fields |= {f"{field}.{name}" for name in default_fields}

        if "productTypes" in fields:
            fields.remove("productTypes")
            fields |= {
                f"productTypes.{name}"
                for name in self.get_default_fields_for_type(
                    "productType"
                )
            }

    def _convert_entity_data(self, entity: AnyEntityDict):
        if not entity or "data" not in entity:
            return

        entity_data = entity["data"] or {}
        if isinstance(entity_data, str):
            entity_data = json.loads(entity_data)

        entity["data"] = entity_data