Source code for ayon_api.graphql

from __future__ import annotations

import copy
import json
import numbers
import typing
from abc import ABC, abstractmethod
from typing import Optional, Iterable, Any, Generator

from .exceptions import GraphQlQueryFailed
from .utils import SortOrder

if typing.TYPE_CHECKING:
    from typing import Union

    from .server_api import ServerAPI

# Sentinel stored in 'fields_to_dict' output for leaf (scalar) fields.
FIELD_VALUE = object()


def fields_to_dict(fields: Optional[Iterable[str]]) -> dict:
    """Convert dotted field paths into a nested dictionary.

    Every path (e.g. 'attrib.fps') is split on '.'; intermediate parts
    become nested dictionaries and the last part is mapped to
    'FIELD_VALUE'. If one path is a prefix of another (e.g. 'attrib' and
    'attrib.fps'), the shorter path wins and stays a leaf, regardless of
    order.

    Args:
        fields (Optional[Iterable[str]]): Dotted field paths.

    Returns:
        dict: Nested dictionary of fields. Empty when 'fields' is empty
            or None.

    """
    result = {}
    for field in fields or ():
        *parents, leaf = field.split(".")
        node = result
        for part in parents:
            # A leaf was already reached - deeper parts are irrelevant.
            if node is FIELD_VALUE:
                break
            node = node.setdefault(part, {})

        if node is not FIELD_VALUE:
            node[leaf] = FIELD_VALUE
    return result
class QueryVariable:
    """Single variable used in a GraphQlQuery.

    The variable is declared in the query header, while its '$'-prefixed
    name is what is referenced from field filters.

    Args:
        variable_name (str): Name of variable in query.

    """

    def __init__(self, variable_name: str) -> None:
        self._variable_name = variable_name
        self._name = f"${variable_name}"

    @property
    def name(self) -> str:
        """Name prefixed with '$', used in field filters."""
        return self._name

    @property
    def variable_name(self) -> str:
        """Bare name, used in the query header definition."""
        return self._variable_name

    def __hash__(self):
        # Hash by '$'-prefixed name; equality stays identity based.
        return hash(self._name)

    def __str__(self) -> str:
        return self._name

    def __format__(self, *args, **kwargs) -> str:
        # Formatting (f-strings, str.format) behaves like the '$' name.
        return self._name.__format__(*args, **kwargs)
class GraphQlQuery:
    """GraphQl query which can have fields to query.

    Single use object which can be used only for one query. Object and
    children objects keep track about paging and progress.

    Args:
        name (str): Name of query.
        order (Optional[int]): Sort order passed to fields created by
            'add_field' and 'add_field_with_edges'. Parsed with
            'SortOrder.parse_value' with 'SortOrder.ascending' as default.

    """

    # Indentation step (in spaces) used by fields when building the query.
    offset = 2

    def __init__(self, name: str, order: Optional[int] = None) -> None:
        self._name = name
        # Variable items by key: {"type": str, "variable": QueryVariable,
        #   "value": Any}
        self._variables = {}
        self._children = []
        # Lazy cache for 'has_multiple_edge_fields'.
        self._has_multiple_edge_fields = None
        self._order = SortOrder.parse_value(order, SortOrder.ascending)

    @property
    def indent(self) -> int:
        """Indentation for preparation of query string.

        Returns:
            int: Indent spaces.

        """
        return 0

    @property
    def child_indent(self) -> int:
        """Indentation for preparation of query string used by children.

        Returns:
            int: Indent spaces for children.

        """
        return self.indent

    @property
    def need_query(self) -> bool:
        """Still need query from server.

        Needed for edges which use pagination.

        Returns:
            bool: If still need query from server.

        """
        return any(child.need_query for child in self._children)

    @property
    def has_multiple_edge_fields(self) -> bool:
        """Query contains more than one edge field (sibling or nested).

        Such query can't yield independent results per page, so
        'continuous_query' keeps filling one output dictionary. Value is
        computed once and cached.

        Returns:
            bool: More than one edge field is in the query.

        """
        if self._has_multiple_edge_fields is None:
            edge_counter = 0
            for child in self._children:
                # Counting beyond 2 is useless - only "more than one"
                #   matters.
                edge_counter += child.sum_edge_fields(2)
                if edge_counter > 1:
                    break
            self._has_multiple_edge_fields = edge_counter > 1
        return self._has_multiple_edge_fields

    def add_variable(
        self, key: str, value_type: str, value: Optional[Any] = None
    ) -> QueryVariable:
        """Add variable to query.

        Args:
            key (str): Variable name.
            value_type (str): Type of expected value in variables. This is
                graphql type e.g. "[String!]", "Int", "Boolean", etc.
            value (Any): Default value for variable. Can be changed later.

        Returns:
            QueryVariable: Created variable object.

        Raises:
            KeyError: If variable was already added before.

        """
        existing = self._variables.get(key)
        if existing is not None:
            # Report the type the variable was registered with, not the
            #   one requested now.
            raise KeyError(
                f'Variable "{key}" was already set with type'
                f' {existing["type"]}.'
            )

        variable = QueryVariable(key)
        self._variables[key] = {
            "type": value_type,
            "variable": variable,
            "value": value,
        }
        return variable

    def get_variable(self, key: str) -> QueryVariable:
        """Variable object.

        Args:
            key (str): Variable name added to headers.

        Returns:
            QueryVariable: Variable object used in query string.

        Raises:
            KeyError: If variable was not added.

        """
        return self._variables[key]["variable"]

    def get_variable_value(
        self, key: str, default: Optional[Any] = None
    ) -> Any:
        """Get current value of variable.

        Args:
            key (str): Variable name.
            default (Any): Value returned if variable is not available.

        Returns:
            Any: Variable value.

        """
        variable_item = self._variables.get(key)
        if variable_item:
            return variable_item["value"]
        return default

    def set_variable_value(self, key: str, value: Any) -> None:
        """Set value for variable.

        Args:
            key (str): Variable name under which the value is stored.
            value (Any): Variable value used in query. Variable is not
                used if value is 'None'.

        Raises:
            KeyError: If variable was not added.

        """
        self._variables[key]["value"] = value

    def get_variable_keys(self) -> set[str]:
        """Get all variable keys.

        Returns:
            set[str]: Variable keys.

        """
        return set(self._variables.keys())

    def get_variables_values(self) -> dict[str, Any]:
        """Calculate variable values that should be used in query.

        Variables with value set to 'None' are skipped.

        Returns:
            dict[str, Any]: Variable values by their name.

        """
        return {
            key: item["value"]
            for key, item in self._variables.items()
            if item["value"] is not None
        }

    def add_obj_field(self, field: BaseGraphQlQueryField) -> None:
        """Add field object to children.

        Args:
            field (BaseGraphQlQueryField): Add field to query children.

        """
        if field in self._children:
            return

        self._children.append(field)
        field.set_parent(self)

    def add_field_with_edges(self, name: str) -> GraphQlQueryEdgeField:
        """Add field with edges to query.

        Args:
            name (str): Field name e.g. 'tasks'.

        Returns:
            GraphQlQueryEdgeField: Created field object.

        """
        item = GraphQlQueryEdgeField(name, self, self._order)
        self.add_obj_field(item)
        return item

    def add_field(self, name: str) -> GraphQlQueryField:
        """Add field to query.

        Args:
            name (str): Field name e.g. 'id'.

        Returns:
            GraphQlQueryField: Created field object.

        """
        item = GraphQlQueryField(name, self, self._order)
        self.add_obj_field(item)
        return item

    def get_field_by_keys(
        self, keys: Iterable[str]
    ) -> Optional[BaseGraphQlQueryField]:
        """Find field by sequence of field names.

        Args:
            keys (Iterable[str]): Field names from top level down.

        Returns:
            Optional[BaseGraphQlQueryField]: Found field or None (also
                for empty keys).

        """
        keys = list(keys)
        if not keys:
            return None

        key = keys.pop(0)
        for child in self._children:
            if child.name == key:
                return child.get_field_by_keys(keys)
        return None

    def get_field_by_path(
        self, path: str
    ) -> Optional[BaseGraphQlQueryField]:
        """Find field by '/' separated path, e.g. 'project/folders'.

        Args:
            path (str): Field names joined by '/'.

        Returns:
            Optional[BaseGraphQlQueryField]: Found field or None.

        """
        return self.get_field_by_keys(path.split("/"))

    def calculate_query(self) -> str:
        """Calculate query string which is sent to server.

        Returns:
            str: GraphQl string with variables and headers.

        Raises:
            ValueError: Query has no fields.

        """
        if not self._children:
            raise ValueError("Missing fields to query")

        # Variables without value are not declared in the header.
        variables = [
            f"{item['variable']}: {item['type']}"
            for item in self._variables.values()
            if item["value"] is not None
        ]
        variables_str = f"({','.join(variables)})" if variables else ""
        header = f"query {self._name}{variables_str}"

        output = [header + " {"]
        output.extend(field.calculate_query() for field in self._children)
        output.append("}")
        return "\n".join(output)

    def parse_result(
        self,
        data: dict[str, Any],
        output: dict[str, Any],
        progress_data: dict[str, Any],
    ) -> None:
        """Parse data from response for output.

        Output is stored to passed 'output' variable. That's because of
        paging during which objects must have access to both new and
        previous values.

        Args:
            data (dict[str, Any]): Data received using calculated query.
            output (dict[str, Any]): Where parsed data are stored.
            progress_data (dict[str, Any]): Data used for paging.

        """
        if not data:
            # Nothing was returned. Children still parse an empty result
            #   so they mark themselves as done; otherwise 'query' and
            #   'continuous_query' would send the same query forever.
            data = {}

        for child in self._children:
            child.parse_result(data, output, progress_data)

    def _query_page(
        self,
        con: ServerAPI,
        output: dict[str, Any],
        progress_data: dict[str, Any],
    ) -> None:
        """Send current query state to server and parse one response."""
        query_str = self.calculate_query()
        variables = self.get_variables_values()
        response = con.query_graphql(query_str, variables)
        if response.errors:
            raise GraphQlQueryFailed(response.errors, query_str, variables)
        self.parse_result(response.data["data"], output, progress_data)

    def query(self, con: ServerAPI) -> dict[str, Any]:
        """Do a query from server, fetching all pages.

        Args:
            con (ServerAPI): Connection to server with 'query_graphql'
                method.

        Returns:
            dict[str, Any]: Parsed output from GraphQl query.

        Raises:
            GraphQlQueryFailed: Server returned errors.

        """
        progress_data = {}
        output = {}
        while self.need_query:
            self._query_page(con, output, progress_data)
        return output

    def continuous_query(
        self, con: ServerAPI
    ) -> Generator[dict[str, Any], None, None]:
        """Do a query from server, yielding output after every page.

        Args:
            con (ServerAPI): Connection to server with 'query_graphql'
                method.

        Yields:
            dict[str, Any]: Parsed output. When the query has multiple
                edge fields, the same (growing) dictionary is yielded after
                each page; otherwise each page gets a new dictionary.

        Raises:
            GraphQlQueryFailed: Server returned errors.

        """
        progress_data = {}
        # Later pages of nested/multiple edges fill objects created by
        #   earlier pages, so the output must be shared between pages.
        reuse_output = self.has_multiple_edge_fields
        output = {}
        while self.need_query:
            if not reuse_output:
                output = {}
            self._query_page(con, output, progress_data)
            yield output
class BaseGraphQlQueryField(ABC):
    """Field in GraphQl query.

    Args:
        name (str): Name of field.
        parent (Union[BaseGraphQlQueryField, GraphQlQuery]): Parent object
            of a field.
        order (SortOrder): Sort order passed to child fields created by
            'add_field' and 'add_field_with_edges'.

    """

    def __init__(
        self,
        name: str,
        parent: Union[BaseGraphQlQueryField, GraphQlQuery],
        order: SortOrder,
    ):
        if isinstance(parent, GraphQlQuery):
            query_item = parent
        else:
            query_item = parent.query_item

        self._name = name
        self._parent = parent

        self._filters = {}

        self._children = []
        # Value is changed on first parse of result
        self._need_query = True

        # Root query object shared by the whole field tree.
        self._query_item = query_item

        # Lazily computed by 'path'.
        self._path = None
        # Maximum number of items to fetch (read by edge fields).
        self._limit = None
        self._order = order
        # Number of items fetched so far (updated by edge fields).
        self._fetched_counter = 0

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.path}>"

    def get_name(self) -> str:
        """Name of the field as used in the query string."""
        return self._name

    name = property(get_name)

    def get_field_by_keys(
        self, keys: Iterable[str]
    ) -> Optional[BaseGraphQlQueryField]:
        """Find descendant field by sequence of field names.

        Only node children are searched (not edge children).

        Args:
            keys (Iterable[str]): Field names below this field.

        Returns:
            Optional[BaseGraphQlQueryField]: Found field, this field for
                empty keys, or None.

        """
        keys = list(keys)
        if not keys:
            return self

        key = keys.pop(0)
        for child in self._children:
            if child.name == key:
                return child.get_field_by_keys(keys)
        return None

    def set_limit(self, limit: Optional[int]) -> None:
        """Set maximum number of items to fetch (None means no limit)."""
        self._limit = limit

    def set_order(self, order: SortOrder) -> None:
        """Set sort order of the field.

        Args:
            order (SortOrder): Value parsed by 'SortOrder.parse_value'.

        Raises:
            ValueError: Value can't be parsed to a sort order.

        """
        parsed_order = SortOrder.parse_value(order)
        if parsed_order is None:
            # Report the value that was passed in, not the parse result.
            raise ValueError(
                f"Got invalid value {order}."
                f" Expected {SortOrder.ascending} or {SortOrder.descending}"
            )
        self._order = parsed_order

    def set_ascending_order(self, enabled: bool = True) -> None:
        """Use ascending order, or descending when 'enabled' is False."""
        self.set_order(
            SortOrder.ascending if enabled else SortOrder.descending
        )

    def set_descending_order(self, enabled: bool = True) -> None:
        """Use descending order, or ascending when 'enabled' is False."""
        self.set_ascending_order(not enabled)

    def add_variable(
        self,
        key: str,
        value_type: str,
        value: Optional[Any] = None,
    ) -> QueryVariable:
        """Add variable to query.

        Args:
            key (str): Variable name.
            value_type (str): Type of expected value in variables. This is
                graphql type e.g. "[String!]", "Int", "Boolean", etc.
            value (Any): Default value for variable. Can be changed later.

        Returns:
            QueryVariable: Created variable object.

        Raises:
            KeyError: If variable was already added before.

        """
        return self._parent.add_variable(key, value_type, value)

    def get_variable(self, key: str) -> QueryVariable:
        """Variable object.

        Args:
            key (str): Variable name added to headers.

        Returns:
            QueryVariable: Variable object used in query string.

        """
        return self._parent.get_variable(key)

    @property
    def need_query(self) -> bool:
        """Still need query from server.

        Needed for edges which use pagination. Look into children values
        too.

        Returns:
            bool: If still need query from server.

        """
        if self._need_query:
            return True
        return any(child.need_query for child in self._children_iter())

    def _children_iter(self) -> Generator[BaseGraphQlQueryField, None, None]:
        """Iterate over all children fields of object.

        Returns:
            Iterator[BaseGraphQlQueryField]: Children fields.

        """
        for child in self._children:
            yield child

    def sum_edge_fields(self, max_limit: Optional[int] = None) -> int:
        """Check how many edge fields query has.

        In case there are multiple edge fields or are nested the query
        can't yield mid cursor results.

        Args:
            max_limit (int): Skip rest of counting if counter reaches
                entered number.

        Returns:
            int: Counter edge fields

        """
        counter = 0
        if isinstance(self, GraphQlQueryEdgeField):
            counter = 1

        for child in self._children_iter():
            counter += child.sum_edge_fields(max_limit)
            if max_limit is not None and counter >= max_limit:
                break
        return counter

    @property
    def offset(self) -> int:
        """Indentation step taken from the root query."""
        return self._query_item.offset

    @property
    def indent(self) -> int:
        """Indentation (in spaces) of this field in the query string."""
        return self._parent.child_indent + self.offset

    @property
    @abstractmethod
    def child_indent(self) -> int:
        pass

    @property
    def query_item(self) -> GraphQlQuery:
        """Root query object of the field tree."""
        return self._query_item

    @property
    @abstractmethod
    def has_edges(self) -> bool:
        pass

    @property
    def child_has_edges(self) -> bool:
        """Any descendant field is an edge field."""
        for child in self._children_iter():
            if child.has_edges or child.child_has_edges:
                return True
        return False

    @property
    def path(self) -> str:
        """Field path for debugging purposes.

        Returns:
            str: Field path in query.

        """
        if self._path is None:
            if isinstance(self._parent, GraphQlQuery):
                path = self._name
            else:
                path = "/".join((self._parent.path, self._name))
            self._path = path
        return self._path

    def reset_cursor(self) -> None:
        """Reset pagination state of all descendant edge fields."""
        for child in self._children_iter():
            child.reset_cursor()

    def get_variable_value(
        self, key: str, default: Optional[Any] = None
    ) -> Any:
        """Get current value of a query variable."""
        return self._query_item.get_variable_value(key, default)

    def set_variable_value(self, key: str, value: Any) -> None:
        """Set value of a query variable."""
        self._query_item.set_variable_value(key, value)

    def set_filter(self, key: str, value: Any) -> None:
        """Set filter (GraphQl argument) of the field."""
        self._filters[key] = value

    def has_filter(self, key: str) -> bool:
        """Filter with the key is set."""
        return key in self._filters

    def remove_filter(self, key: str) -> None:
        """Remove filter if it is set."""
        self._filters.pop(key, None)

    def set_parent(
        self, parent: Union[BaseGraphQlQueryField, GraphQlQuery]
    ) -> None:
        """Change parent and register this field in it."""
        if self._parent is parent:
            return
        self._parent = parent
        parent.add_obj_field(self)

    def add_obj_field(self, field: BaseGraphQlQueryField) -> None:
        """Add field object to children (no-op if already added)."""
        if field in self._children:
            return

        self._children.append(field)
        field.set_parent(self)

    def add_field_with_edges(self, name: str) -> GraphQlQueryEdgeField:
        """Create and add child field with edges (paginated)."""
        item = GraphQlQueryEdgeField(name, self, self._order)
        self.add_obj_field(item)
        return item

    def add_field(self, name: str) -> GraphQlQueryField:
        """Create and add plain child field."""
        item = GraphQlQueryField(name, self, self._order)
        self.add_obj_field(item)
        return item

    def _filter_value_to_str(self, value: Any) -> Optional[str]:
        """Convert filter value to GraphQl literal.

        Returns:
            Optional[str]: Literal string, or None for a variable that
                has no value (such filter is skipped).

        Raises:
            TypeError: Unsupported value type.

        """
        # 'bool' is a subclass of 'int' (and so 'numbers.Number'); GraphQl
        #   booleans are lowercase.
        if isinstance(value, bool):
            return "true" if value else "false"

        if isinstance(value, numbers.Number):
            return str(value)

        if isinstance(value, str):
            # JSON string escaping ('"', '\\', control characters) is
            #   valid GraphQl string syntax; non-ASCII is kept as is.
            return json.dumps(value, ensure_ascii=False)

        if isinstance(value, (list, set, tuple)):
            joined_values = ", ".join(
                self._filter_value_to_str(item)
                for item in value
            )
            return f"[{joined_values}]"

        if isinstance(value, QueryVariable):
            if self.get_variable_value(value.variable_name) is None:
                return None
            return str(value)

        raise TypeError(
            "Unknown type to convert '{}'".format(str(type(value)))
        )

    def get_filters(self) -> dict[str, Any]:
        """Receive filters for item.

        By default just use copy of set filters.

        Returns:
            dict[str, Any]: Fields filters.

        """
        return copy.deepcopy(self._filters)

    def _filters_to_string(self) -> str:
        """Build GraphQl arguments string, e.g. '(name: "a", first: 10)'.

        Returns empty string when there is nothing to filter.
        """
        filters = self.get_filters()
        if not filters:
            return ""

        filter_items = []
        for key, value in filters.items():
            string_value = self._filter_value_to_str(value)
            if string_value is None:
                continue
            filter_items.append(f"{key}: {string_value}")

        if not filter_items:
            return ""
        return "({})".format(", ".join(filter_items))

    def _fake_children_parse(self) -> None:
        """Mark children as they don't need query."""
        for child in self._children_iter():
            child.parse_result({}, {}, {})

    @abstractmethod
    def calculate_query(self) -> str:
        pass

    @abstractmethod
    def parse_result(
        self,
        data: dict[str, Any],
        output: dict[str, Any],
        progress_data: dict[str, Any],
    ) -> None:
        pass
class GraphQlQueryField(BaseGraphQlQueryField):
    """Plain (non-paginated) field in GraphQl query.

    Received value can be a scalar, a dictionary parsed by children, or a
    list of dictionaries parsed item by item by children.
    """

    has_edges = False

    @property
    def child_indent(self) -> int:
        # 'indent' already contains this field's offset; children add
        #   their own offset on top of it.
        return self.indent

    def parse_result(
        self,
        data: dict[str, Any],
        output: dict[str, Any],
        progress_data: dict[str, Any],
    ) -> None:
        """Store value of this field from 'data' into 'output'.

        Args:
            data (dict[str, Any]): Response data of parent level.
            output (dict[str, Any]): Parent level output to fill.
            progress_data (dict[str, Any]): Data used for paging.

        Raises:
            TypeError: 'data' is not a dictionary.

        """
        if not isinstance(data, dict):
            raise TypeError(
                f"{self._name} Expected 'dict' type got '{type(data)}'"
            )

        # Plain field is resolved by a single response.
        self._need_query = False
        field_name = self._name
        received = data.get(field_name)

        if received is None:
            # Children have nothing to parse - mark them as done.
            self._fake_children_parse()
            if field_name in data:
                output[field_name] = None
            return

        if not self._children:
            output[field_name] = received
            return

        stored = output.get(field_name)

        if isinstance(received, dict):
            if stored is None:
                stored = output[field_name] = {}
            for child in self._children:
                child.parse_result(received, stored, progress_data)
            return

        if stored is None:
            stored = output[field_name] = []

        if not received:
            self._fake_children_parse()
            return

        # Items from previous pages are reused; only missing item
        #   dictionaries are created so children merge into them.
        missing = len(received) - len(stored)
        if missing > 0:
            stored.extend({} for _ in range(missing))

        for received_item, stored_item in zip(received, stored):
            for child in self._children:
                child.parse_result(received_item, stored_item, progress_data)

    def calculate_query(self) -> str:
        """Build query string of this field and its children."""
        padding = " " * self.indent
        header = f"{padding}{self._name}{self._filters_to_string()}"
        if not self._children:
            return header

        lines = [f"{header} {{"]
        lines.extend(child.calculate_query() for child in self._children)
        lines.append(f"{padding}}}")
        return "\n".join(lines)
class GraphQlQueryEdgeField(BaseGraphQlQueryField):
    """Field with edges (cursor based pagination) in GraphQl query.

    Query string contains 'edges' with 'node' and 'pageInfo'
    ('endCursor', 'hasNextPage'). Fields added with 'add_field' or
    'add_field_with_edges' are queried inside 'node', fields added with
    'add_edge_field' are queried on the edge itself.

    """

    has_edges = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 'endCursor' of last fully processed page, sent as 'after'.
        self._cursor = None
        # Fields queried directly on edges (next to 'node').
        self._edge_children = []

    @property
    def child_indent(self) -> int:
        # Children are nested below both 'edges' and 'node' levels.
        offset = self.offset * 2
        return self.indent + offset

    def _children_iter(self) -> Generator[BaseGraphQlQueryField, None, None]:
        """Iterate over node children and then edge children."""
        yield from super()._children_iter()
        yield from self._edge_children

    def add_obj_field(self, field: BaseGraphQlQueryField) -> None:
        """Add field queried inside 'node'.

        Field already registered as edge field is ignored.
        """
        if field in self._edge_children:
            return

        super().add_obj_field(field)

    def add_obj_edge_field(self, field: BaseGraphQlQueryField) -> None:
        """Add field queried on the edge (next to 'node')."""
        if field in self._edge_children or field in self._children:
            return

        self._edge_children.append(field)
        field.set_parent(self)

    def add_edge_field(self, name: str) -> GraphQlQueryField:
        """Create and add plain field queried on the edge."""
        item = GraphQlQueryField(name, self, self._order)
        self.add_obj_edge_field(item)
        return item

    def reset_cursor(self) -> None:
        """Restart pagination of this field and descendant edge fields."""
        # Reset cursor only for edges
        self._cursor = None
        self._need_query = True
        super().reset_cursor()

    def parse_result(
        self,
        data: dict[str, Any],
        output: dict[str, Any],
        progress_data: dict[str, Any],
    ) -> None:
        """Parse one page of edges from 'data' into 'output'.

        Args:
            data (dict[str, Any]): Response data of parent level.
            output (dict[str, Any]): Parent level output to fill.
            progress_data (dict[str, Any]): Shared paging data; used to
                map edge cursors to output dictionaries when descendants
                have edges too.

        Raises:
            TypeError: 'data' is not a dictionary.

        """
        if not isinstance(data, dict):
            raise TypeError("{} Expected 'dict' type got '{}'".format(
                self._name, str(type(data))
            ))

        value = data.get(self._name)
        if value is None:
            self._fake_children_parse()
            self._need_query = False
            return

        if self._name in output:
            node_values = output[self._name]
        else:
            node_values = []
            output[self._name] = node_values

        nodes_by_cursor = {}
        # With nested edges the same edges are received again while
        #   children paginate; cursor maps them to already created output.
        handle_cursors = self.child_has_edges
        if handle_cursors:
            cursor_key = self._get_cursor_key()
            if cursor_key in progress_data:
                nodes_by_cursor = progress_data[cursor_key]
            else:
                progress_data[cursor_key] = nodes_by_cursor

        page_info = value["pageInfo"]
        new_cursor = page_info["endCursor"]
        self._need_query = page_info["hasNextPage"]
        edges = value["edges"]
        # Fake result parse
        if not edges:
            self._fake_children_parse()

        self._fetched_counter += len(edges)
        if self._limit and self._fetched_counter >= self._limit:
            self._need_query = False

        for edge in edges:
            if not handle_cursors:
                edge_value = {}
                node_values.append(edge_value)
            else:
                edge_cursor = edge["cursor"]
                edge_value = nodes_by_cursor.get(edge_cursor)
                if edge_value is None:
                    edge_value = {}
                    nodes_by_cursor[edge_cursor] = edge_value
                    node_values.append(edge_value)

            for child in self._edge_children:
                child.parse_result(edge, edge_value, progress_data)

            for child in self._children:
                child.parse_result(edge["node"], edge_value, progress_data)

        if not self._need_query:
            return

        # Move to next page only when all children finished paginating
        #   the current one.
        change_cursor = True
        for child in self._children_iter():
            if child.need_query:
                change_cursor = False

        if change_cursor:
            for child in self._children_iter():
                child.reset_cursor()
            self._cursor = new_cursor

    def _get_cursor_key(self) -> str:
        """Key in 'progress_data' for edge-cursor to output mapping."""
        return f"{self.path}/__cursor__"

    def get_filters(self) -> dict[str, Any]:
        """Filters extended by page size and cursor.

        Returns:
            dict[str, Any]: Field filters.

        """
        filters = super().get_filters()
        # NOTE(review): descending order uses 'last' together with
        #   'after'/'endCursor'/'hasNextPage'; backward pagination usually
        #   uses 'before'/'startCursor'/'hasPreviousPage' - confirm against
        #   server behavior.
        limit_key = "first"
        if self._order == SortOrder.descending:
            limit_key = "last"

        limit_amount = 300
        if self._limit:
            remaining = self._limit - self._fetched_counter
            # '_fetched_counter' accumulates edges of every parent node and
            #   is not reset by 'reset_cursor', so it can exceed the limit;
            #   never request a negative amount.
            limit_amount = max(0, min(limit_amount, remaining))

        filters[limit_key] = limit_amount

        if self._cursor:
            filters["after"] = self._cursor
        return filters

    def calculate_query(self) -> str:
        """Build query string with 'edges', 'node' and 'pageInfo'.

        Raises:
            ValueError: Field has no children.

        """
        if not self._children and not self._edge_children:
            raise ValueError("Missing child definitions for edges {}".format(
                self.path
            ))

        offset = self.indent * " "
        header = f"{offset}{self._name}{self._filters_to_string()}"

        output = []
        output.append(header + " {")

        edges_offset = offset + self.offset * " "
        node_offset = edges_offset + self.offset * " "
        output.append(edges_offset + "edges {")
        for field in self._edge_children:
            output.append(field.calculate_query())

        if self._children:
            output.append(node_offset + "node {")

            for field in self._children:
                output.append(
                    field.calculate_query()
                )

            output.append(node_offset + "}")

        # 'parse_result' reads 'edge["cursor"]' exactly when
        #   'child_has_edges' is True, so request it in the same case.
        if self.child_has_edges:
            output.append(node_offset + "cursor")

        output.append(edges_offset + "}")

        # Add page information
        output.append(edges_offset + "pageInfo {")
        for page_key in (
            "endCursor",
            "hasNextPage",
        ):
            output.append(node_offset + page_key)
        output.append(edges_offset + "}")
        output.append(offset + "}")

        return "\n".join(output)
# GraphQL schema introspection query. It requests the names of the root
# operation types (query, mutation, subscription), all types through the
# 'FullType' fragment (including deprecated fields and enum values) and all
# directives. The 'TypeRef' fragment resolves wrapped types through seven
# nested 'ofType' levels.
INTROSPECTION_QUERY = """ query IntrospectionQuery { __schema { queryType { name } mutationType { name } subscriptionType { name } types { ...FullType } directives { name description locations args { ...InputValue } } } } fragment FullType on __Type { kind name description fields(includeDeprecated: true) { name description args { ...InputValue } type { ...TypeRef } isDeprecated deprecationReason } inputFields { ...InputValue } interfaces { ...TypeRef } enumValues(includeDeprecated: true) { name description isDeprecated deprecationReason } possibleTypes { ...TypeRef } } fragment InputValue on __InputValue { name description type { ...TypeRef } defaultValue } fragment TypeRef on __Type { kind name ofType { kind name ofType { kind name ofType { kind name ofType { kind name ofType { kind name ofType { kind name ofType { kind name } } } } } } } } """