Source code for keystone.federation.backends.sql

# Copyright 2014 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from oslo_log import log
from oslo_serialization import jsonutils
from sqlalchemy import orm

from keystone.common import sql
import keystone.conf
from keystone import exception
from keystone.federation.backends import base
from keystone.i18n import _

CONF = keystone.conf.CONF
LOG = log.getLogger(__name__)


[docs] class FederationProtocolModel(sql.ModelBase, sql.ModelDictMixin): __tablename__ = 'federation_protocol' attributes = ['id', 'idp_id', 'mapping_id', 'remote_id_attribute'] mutable_attributes = frozenset(['mapping_id', 'remote_id_attribute']) id = sql.Column(sql.String(64), primary_key=True) idp_id = sql.Column( sql.String(64), sql.ForeignKey('identity_provider.id', ondelete='CASCADE'), primary_key=True, ) mapping_id = sql.Column(sql.String(64), nullable=False) remote_id_attribute = sql.Column(sql.String(64))
[docs] @classmethod def from_dict(cls, dictionary): new_dictionary = dictionary.copy() return cls(**new_dictionary)
[docs] def to_dict(self): """Return a dictionary with model's attributes.""" d = dict() for attr in self.__class__.attributes: d[attr] = getattr(self, attr) return d
[docs] class IdentityProviderModel(sql.ModelBase, sql.ModelDictMixin): __tablename__ = 'identity_provider' attributes = [ 'id', 'domain_id', 'enabled', 'description', 'remote_ids', 'authorization_ttl', ] mutable_attributes = frozenset( ['description', 'enabled', 'remote_ids', 'authorization_ttl'] ) id = sql.Column(sql.String(64), primary_key=True) domain_id = sql.Column(sql.String(64), nullable=False) enabled = sql.Column(sql.Boolean, nullable=False) description = sql.Column(sql.Text(), nullable=True) authorization_ttl = sql.Column(sql.Integer, nullable=True) remote_ids = orm.relationship( 'IdPRemoteIdsModel', order_by='IdPRemoteIdsModel.remote_id', cascade='all, delete-orphan', ) expiring_user_group_memberships = orm.relationship( 'ExpiringUserGroupMembership', cascade='all, delete-orphan', backref="idp", )
[docs] @classmethod def from_dict(cls, dictionary): new_dictionary = dictionary.copy() remote_ids_list = new_dictionary.pop('remote_ids', None) if not remote_ids_list: remote_ids_list = [] identity_provider = cls(**new_dictionary) remote_ids = [] # NOTE(fmarco76): the remote_ids_list contains only remote ids # associated with the IdP because of the "relationship" established in # sqlalchemy and corresponding to the FK in the idp_remote_ids table for remote in remote_ids_list: remote_ids.append(IdPRemoteIdsModel(remote_id=remote)) identity_provider.remote_ids = remote_ids return identity_provider
[docs] def to_dict(self): """Return a dictionary with model's attributes.""" d = dict() for attr in self.__class__.attributes: d[attr] = getattr(self, attr) d['remote_ids'] = [] for remote in self.remote_ids: d['remote_ids'].append(remote.remote_id) return d
[docs] class IdPRemoteIdsModel(sql.ModelBase, sql.ModelDictMixin): __tablename__ = 'idp_remote_ids' attributes = ['idp_id', 'remote_id'] mutable_attributes = frozenset(['idp_id', 'remote_id']) idp_id = sql.Column( sql.String(64), sql.ForeignKey('identity_provider.id', ondelete='CASCADE'), ) remote_id = sql.Column(sql.String(255), primary_key=True)
[docs] @classmethod def from_dict(cls, dictionary): new_dictionary = dictionary.copy() return cls(**new_dictionary)
[docs] def to_dict(self): """Return a dictionary with model's attributes.""" d = dict() for attr in self.__class__.attributes: d[attr] = getattr(self, attr) return d
[docs] class MappingModel(sql.ModelBase, sql.ModelDictMixin): __tablename__ = 'mapping' attributes = ['id', 'rules', 'schema_version'] id = sql.Column(sql.String(64), primary_key=True) rules = sql.Column(sql.JsonBlob(), nullable=False) schema_version = sql.Column( sql.String(5), nullable=False, server_default='1.0' )
[docs] @classmethod def from_dict(cls, dictionary): new_dictionary = dictionary.copy() new_dictionary['rules'] = jsonutils.dumps(new_dictionary['rules']) return cls(**new_dictionary)
[docs] def to_dict(self): """Return a dictionary with model's attributes.""" d = dict() for attr in self.__class__.attributes: d[attr] = getattr(self, attr) d['rules'] = jsonutils.loads(d['rules']) return d
[docs] class ServiceProviderModel(sql.ModelBase, sql.ModelDictMixin): __tablename__ = 'service_provider' attributes = [ 'auth_url', 'id', 'enabled', 'description', 'relay_state_prefix', 'sp_url', ] mutable_attributes = frozenset( ['auth_url', 'description', 'enabled', 'relay_state_prefix', 'sp_url'] ) id = sql.Column(sql.String(64), primary_key=True) enabled = sql.Column(sql.Boolean, nullable=False) description = sql.Column(sql.Text(), nullable=True) auth_url = sql.Column(sql.String(256), nullable=False) sp_url = sql.Column(sql.String(256), nullable=False) relay_state_prefix = sql.Column(sql.String(256), nullable=False)
[docs] @classmethod def from_dict(cls, dictionary): new_dictionary = dictionary.copy() return cls(**new_dictionary)
[docs] def to_dict(self): """Return a dictionary with model's attributes.""" d = dict() for attr in self.__class__.attributes: d[attr] = getattr(self, attr) return d
[docs] class Federation(base.FederationDriverBase): _CONFLICT_LOG_MSG = 'Conflict %(conflict_type)s: %(details)s' def _handle_idp_conflict(self, e): conflict_type = 'identity_provider' details = str(e) LOG.debug( self._CONFLICT_LOG_MSG, {'conflict_type': conflict_type, 'details': details}, ) if 'remote_id' in details: msg = _('Duplicate remote ID: %s') else: msg = _('Duplicate entry: %s') msg = msg % e.value raise exception.Conflict(type=conflict_type, details=msg) # Identity Provider CRUD
[docs] def create_idp(self, idp_id, idp): idp['id'] = idp_id try: with sql.session_for_write() as session: idp_ref = IdentityProviderModel.from_dict(idp) session.add(idp_ref) return idp_ref.to_dict() except sql.DBDuplicateEntry as e: self._handle_idp_conflict(e)
[docs] def delete_idp(self, idp_id): with sql.session_for_write() as session: self._delete_assigned_protocols(session, idp_id) idp_ref = self._get_idp(session, idp_id) session.delete(idp_ref)
def _get_idp(self, session, idp_id): idp_ref = session.get(IdentityProviderModel, idp_id) if not idp_ref: raise exception.IdentityProviderNotFound(idp_id=idp_id) return idp_ref def _get_idp_from_remote_id(self, session, remote_id): q = session.query(IdPRemoteIdsModel) q = q.filter_by(remote_id=remote_id) try: return q.one() except sql.NotFound: raise exception.IdentityProviderNotFound(idp_id=remote_id)
[docs] def list_idps(self, hints=None): with sql.session_for_read() as session: query = session.query(IdentityProviderModel) idps = sql.filter_limit_query(IdentityProviderModel, query, hints) idps_list = [idp.to_dict() for idp in idps] return idps_list
[docs] def get_idp(self, idp_id): with sql.session_for_read() as session: idp_ref = self._get_idp(session, idp_id) return idp_ref.to_dict()
[docs] def get_idp_from_remote_id(self, remote_id): with sql.session_for_read() as session: ref = self._get_idp_from_remote_id(session, remote_id) return ref.to_dict()
[docs] def update_idp(self, idp_id, idp): try: with sql.session_for_write() as session: idp_ref = self._get_idp(session, idp_id) old_idp = idp_ref.to_dict() old_idp.update(idp) new_idp = IdentityProviderModel.from_dict(old_idp) for attr in IdentityProviderModel.mutable_attributes: setattr(idp_ref, attr, getattr(new_idp, attr)) return idp_ref.to_dict() except sql.DBDuplicateEntry as e: self._handle_idp_conflict(e)
# Protocol CRUD def _get_protocol(self, session, idp_id, protocol_id): q = session.query(FederationProtocolModel) q = q.filter_by(id=protocol_id, idp_id=idp_id) try: return q.one() except sql.NotFound: kwargs = {'protocol_id': protocol_id, 'idp_id': idp_id} raise exception.FederatedProtocolNotFound(**kwargs)
[docs] @sql.handle_conflicts(conflict_type='federation_protocol') def create_protocol(self, idp_id, protocol_id, protocol): protocol['id'] = protocol_id protocol['idp_id'] = idp_id with sql.session_for_write() as session: self._get_idp(session, idp_id) protocol_ref = FederationProtocolModel.from_dict(protocol) session.add(protocol_ref) return protocol_ref.to_dict()
[docs] def update_protocol(self, idp_id, protocol_id, protocol): with sql.session_for_write() as session: proto_ref = self._get_protocol(session, idp_id, protocol_id) old_proto = proto_ref.to_dict() old_proto.update(protocol) new_proto = FederationProtocolModel.from_dict(old_proto) for attr in FederationProtocolModel.mutable_attributes: setattr(proto_ref, attr, getattr(new_proto, attr)) return proto_ref.to_dict()
[docs] def get_protocol(self, idp_id, protocol_id): with sql.session_for_read() as session: protocol_ref = self._get_protocol(session, idp_id, protocol_id) return protocol_ref.to_dict()
[docs] def list_protocols(self, idp_id): with sql.session_for_read() as session: q = session.query(FederationProtocolModel) q = q.filter_by(idp_id=idp_id) protocols = [protocol.to_dict() for protocol in q] return protocols
[docs] def delete_protocol(self, idp_id, protocol_id): with sql.session_for_write() as session: key_ref = self._get_protocol(session, idp_id, protocol_id) session.delete(key_ref)
def _delete_assigned_protocols(self, session, idp_id): query = session.query(FederationProtocolModel) query = query.filter_by(idp_id=idp_id) query.delete() # Mapping CRUD def _get_mapping(self, session, mapping_id): mapping_ref = session.get(MappingModel, mapping_id) if not mapping_ref: raise exception.MappingNotFound(mapping_id=mapping_id) return mapping_ref
[docs] @sql.handle_conflicts(conflict_type='mapping') def create_mapping(self, mapping_id, mapping): ref = {} ref['id'] = mapping_id ref['rules'] = mapping.get('rules') ref['schema_version'] = mapping.get('schema_version') with sql.session_for_write() as session: mapping_ref = MappingModel.from_dict(ref) session.add(mapping_ref) return mapping_ref.to_dict()
[docs] def delete_mapping(self, mapping_id): with sql.session_for_write() as session: mapping_ref = self._get_mapping(session, mapping_id) session.delete(mapping_ref)
[docs] def list_mappings(self): with sql.session_for_read() as session: mappings = session.query(MappingModel) return [x.to_dict() for x in mappings]
[docs] def get_mapping(self, mapping_id): with sql.session_for_read() as session: mapping_ref = self._get_mapping(session, mapping_id) return mapping_ref.to_dict()
[docs] @sql.handle_conflicts(conflict_type='mapping') def update_mapping(self, mapping_id, mapping): ref = {} ref['id'] = mapping_id ref['rules'] = mapping.get('rules') if mapping.get('schema_version'): ref['schema_version'] = mapping.get('schema_version') with sql.session_for_write() as session: mapping_ref = self._get_mapping(session, mapping_id) old_mapping = mapping_ref.to_dict() old_mapping.update(ref) new_mapping = MappingModel.from_dict(old_mapping) for attr in MappingModel.attributes: setattr(mapping_ref, attr, getattr(new_mapping, attr)) return mapping_ref.to_dict()
[docs] def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): with sql.session_for_read() as session: protocol_ref = self._get_protocol(session, idp_id, protocol_id) mapping_id = protocol_ref.mapping_id mapping_ref = self._get_mapping(session, mapping_id) return mapping_ref.to_dict()
# Service Provider CRUD
[docs] @sql.handle_conflicts(conflict_type='service_provider') def create_sp(self, sp_id, sp): sp['id'] = sp_id with sql.session_for_write() as session: sp_ref = ServiceProviderModel.from_dict(sp) session.add(sp_ref) return sp_ref.to_dict()
[docs] def delete_sp(self, sp_id): with sql.session_for_write() as session: sp_ref = self._get_sp(session, sp_id) session.delete(sp_ref)
def _get_sp(self, session, sp_id): sp_ref = session.get(ServiceProviderModel, sp_id) if not sp_ref: raise exception.ServiceProviderNotFound(sp_id=sp_id) return sp_ref
[docs] def list_sps(self, hints=None): with sql.session_for_read() as session: query = session.query(ServiceProviderModel) sps = sql.filter_limit_query(ServiceProviderModel, query, hints) sps_list = [sp.to_dict() for sp in sps] return sps_list
[docs] def get_sp(self, sp_id): with sql.session_for_read() as session: sp_ref = self._get_sp(session, sp_id) return sp_ref.to_dict()
[docs] def update_sp(self, sp_id, sp): with sql.session_for_write() as session: sp_ref = self._get_sp(session, sp_id) old_sp = sp_ref.to_dict() old_sp.update(sp) new_sp = ServiceProviderModel.from_dict(old_sp) for attr in ServiceProviderModel.mutable_attributes: setattr(sp_ref, attr, getattr(new_sp, attr)) return sp_ref.to_dict()
[docs] def get_enabled_service_providers(self): with sql.session_for_read() as session: service_providers = session.query(ServiceProviderModel) service_providers = service_providers.filter_by(enabled=True) return service_providers