# Copyright 2014 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_log import log
from oslo_serialization import jsonutils
from sqlalchemy import orm
from keystone.common import sql
import keystone.conf
from keystone import exception
from keystone.federation.backends import base
from keystone.i18n import _
# Module-level logger, named after this module's import path.
LOG = log.getLogger(__name__)
# Global keystone configuration object.
CONF = keystone.conf.CONF
class FederationProtocolModel(sql.ModelBase, sql.ModelDictMixin):
    """SQL model for a federation protocol attached to an identity provider.

    The primary key is the composite ``(id, idp_id)``; ``idp_id`` references
    ``identity_provider.id`` with ``ondelete='CASCADE'``.
    """

    __tablename__ = 'federation_protocol'
    # Keys emitted by to_dict(), in this order.
    attributes = ['id', 'idp_id', 'mapping_id', 'remote_id_attribute']
    # Attributes an update is allowed to overwrite on an existing row.
    mutable_attributes = frozenset(['mapping_id', 'remote_id_attribute'])

    id = sql.Column(sql.String(64), primary_key=True)
    idp_id = sql.Column(
        sql.String(64),
        sql.ForeignKey('identity_provider.id', ondelete='CASCADE'),
        primary_key=True,
    )
    mapping_id = sql.Column(sql.String(64), nullable=False)
    remote_id_attribute = sql.Column(sql.String(64))

    @classmethod
    def from_dict(cls, dictionary):
        """Build a protocol model from a plain dictionary.

        :param dictionary: column values keyed by attribute name; the
            caller's dict is not modified.
        :returns: a new, unsaved ``FederationProtocolModel``.
        """
        # ``**`` unpacking already builds a fresh mapping, so no explicit
        # copy is needed to leave the caller's dict untouched.
        return cls(**dictionary)

    def to_dict(self):
        """Return a dictionary with model's attributes."""
        return {name: getattr(self, name) for name in type(self).attributes}
class IdentityProviderModel(sql.ModelBase, sql.ModelDictMixin):
    """SQL model for a federated identity provider (IdP).

    Remote IDs are stored as ``IdPRemoteIdsModel`` rows reached through the
    ``remote_ids`` relationship; ``from_dict``/``to_dict`` convert them to
    and from a plain list of remote ID strings.
    """

    __tablename__ = 'identity_provider'
    # Keys emitted by to_dict(), in this order.
    attributes = [
        'id',
        'domain_id',
        'enabled',
        'description',
        'remote_ids',
        'authorization_ttl',
    ]
    # Attributes an update is allowed to overwrite on an existing row.
    mutable_attributes = frozenset(
        ['description', 'enabled', 'remote_ids', 'authorization_ttl']
    )

    id = sql.Column(sql.String(64), primary_key=True)
    domain_id = sql.Column(sql.String(64), nullable=False)
    enabled = sql.Column(sql.Boolean, nullable=False)
    description = sql.Column(sql.Text(), nullable=True)
    authorization_ttl = sql.Column(sql.Integer, nullable=True)
    # Ordered by remote_id. 'delete-orphan' deletes child rows that are
    # removed from this collection, e.g. when the list is replaced.
    remote_ids = orm.relationship(
        'IdPRemoteIdsModel',
        order_by='IdPRemoteIdsModel.remote_id',
        cascade='all, delete-orphan',
    )
    expiring_user_group_memberships = orm.relationship(
        'ExpiringUserGroupMembership',
        cascade='all, delete-orphan',
        backref="idp",
    )

    @classmethod
    def from_dict(cls, dictionary):
        """Build an IdP model, turning remote ID strings into child rows.

        :param dictionary: IdP attributes; an optional ``remote_ids`` entry
            holds a list of remote ID strings (a missing or falsy value is
            treated as an empty list). The caller's dict is not modified.
        :returns: a new, unsaved ``IdentityProviderModel``.
        """
        values = dictionary.copy()
        remote_id_strings = values.pop('remote_ids', None) or []
        identity_provider = cls(**values)
        # NOTE(fmarco76): these remote ids are associated only with this IdP
        # because of the "relationship" established in sqlalchemy and
        # corresponding to the FK in the idp_remote_ids table
        identity_provider.remote_ids = [
            IdPRemoteIdsModel(remote_id=remote_id)
            for remote_id in remote_id_strings
        ]
        return identity_provider

    def to_dict(self):
        """Return a dictionary with model's attributes.

        ``remote_ids`` is flattened to a list of remote ID strings.
        """
        result = {name: getattr(self, name) for name in type(self).attributes}
        result['remote_ids'] = [row.remote_id for row in self.remote_ids]
        return result
class IdPRemoteIdsModel(sql.ModelBase, sql.ModelDictMixin):
    """SQL model linking one remote ID to the identity provider owning it.

    ``remote_id`` alone is the primary key, so a given remote ID can be
    recorded against at most one identity provider.
    """

    __tablename__ = 'idp_remote_ids'
    # Keys emitted by to_dict(), in this order.
    attributes = ['idp_id', 'remote_id']
    mutable_attributes = frozenset(['idp_id', 'remote_id'])

    idp_id = sql.Column(
        sql.String(64),
        sql.ForeignKey('identity_provider.id', ondelete='CASCADE'),
    )
    remote_id = sql.Column(sql.String(255), primary_key=True)

    @classmethod
    def from_dict(cls, dictionary):
        """Build a remote-ID model from a plain dictionary.

        :param dictionary: column values keyed by attribute name; the
            caller's dict is not modified.
        :returns: a new, unsaved ``IdPRemoteIdsModel``.
        """
        return cls(**dictionary)

    def to_dict(self):
        """Return a dictionary with model's attributes."""
        return {name: getattr(self, name) for name in type(self).attributes}
class MappingModel(sql.ModelBase, sql.ModelDictMixin):
    """SQL model for a federation mapping and its rules."""

    __tablename__ = 'mapping'
    # Keys emitted by to_dict(), in this order.
    attributes = ['id', 'rules', 'schema_version']

    id = sql.Column(sql.String(64), primary_key=True)
    rules = sql.Column(sql.JsonBlob(), nullable=False)
    schema_version = sql.Column(
        sql.String(5), nullable=False, server_default='1.0'
    )

    @classmethod
    def from_dict(cls, dictionary):
        """Build a mapping model, serializing ``rules`` with jsonutils.

        :param dictionary: mapping attributes; must contain ``rules``. The
            caller's dict is not modified.
        :returns: a new, unsaved ``MappingModel``.
        :raises KeyError: if ``rules`` is missing.
        """
        # to_dict() applies the inverse jsonutils.loads(); keep both sides
        # in sync.
        values = dict(dictionary, rules=jsonutils.dumps(dictionary['rules']))
        return cls(**values)

    def to_dict(self):
        """Return a dictionary with model's attributes.

        ``rules`` is decoded with ``jsonutils.loads``, mirroring from_dict().
        """
        result = {name: getattr(self, name) for name in type(self).attributes}
        result['rules'] = jsonutils.loads(result['rules'])
        return result
class ServiceProviderModel(sql.ModelBase, sql.ModelDictMixin):
    """SQL model for a federation service provider (SP)."""

    __tablename__ = 'service_provider'
    # Keys emitted by to_dict(), in this order.
    attributes = [
        'auth_url',
        'id',
        'enabled',
        'description',
        'relay_state_prefix',
        'sp_url',
    ]
    # Attributes an update is allowed to overwrite on an existing row.
    mutable_attributes = frozenset(
        ['auth_url', 'description', 'enabled', 'relay_state_prefix', 'sp_url']
    )

    id = sql.Column(sql.String(64), primary_key=True)
    enabled = sql.Column(sql.Boolean, nullable=False)
    description = sql.Column(sql.Text(), nullable=True)
    auth_url = sql.Column(sql.String(256), nullable=False)
    sp_url = sql.Column(sql.String(256), nullable=False)
    relay_state_prefix = sql.Column(sql.String(256), nullable=False)

    @classmethod
    def from_dict(cls, dictionary):
        """Build a service provider model from a plain dictionary.

        :param dictionary: column values keyed by attribute name; the
            caller's dict is not modified.
        :returns: a new, unsaved ``ServiceProviderModel``.
        """
        return cls(**dictionary)

    def to_dict(self):
        """Return a dictionary with model's attributes."""
        return {name: getattr(self, name) for name in type(self).attributes}
class Federation(base.FederationDriverBase):
    """SQL backend implementing the federation driver interface.

    Manages identity providers (and their remote IDs), federation
    protocols, mappings and service providers. Unless noted otherwise,
    public methods return plain dictionaries produced by the models'
    ``to_dict()``.
    """

    _CONFLICT_LOG_MSG = 'Conflict %(conflict_type)s: %(details)s'

    def _handle_idp_conflict(self, e):
        """Translate a duplicate-entry DB error into an API ``Conflict``.

        Always raises; callers invoke it from an ``except`` clause.

        :param e: the caught ``sql.DBDuplicateEntry``.
        :raises keystone.exception.Conflict: always.
        """
        conflict_type = 'identity_provider'
        details = str(e)
        LOG.debug(
            self._CONFLICT_LOG_MSG,
            {'conflict_type': conflict_type, 'details': details},
        )
        # A duplicate in the remote_id column gets a more specific message;
        # anything else is reported as a generic duplicate entry.
        if 'remote_id' in details:
            msg = _('Duplicate remote ID: %s')
        else:
            msg = _('Duplicate entry: %s')
        msg = msg % e.value
        raise exception.Conflict(type=conflict_type, details=msg) from e

    # Identity Provider CRUD

    def create_idp(self, idp_id, idp):
        """Create an identity provider.

        :param idp_id: ID for the new identity provider.
        :param idp: IdP attributes; ``idp['id']`` is set to ``idp_id`` in
            place on the caller's dict.
        :returns: the created IdP as a dictionary.
        :raises keystone.exception.Conflict: on a duplicate entry.
        """
        idp['id'] = idp_id
        try:
            with sql.session_for_write() as session:
                idp_ref = IdentityProviderModel.from_dict(idp)
                session.add(idp_ref)
                return idp_ref.to_dict()
        except sql.DBDuplicateEntry as e:
            self._handle_idp_conflict(e)

    def delete_idp(self, idp_id):
        """Delete an identity provider and the protocols assigned to it.

        :raises keystone.exception.IdentityProviderNotFound: if the IdP
            does not exist.
        """
        with sql.session_for_write() as session:
            # Protocols are removed explicitly instead of relying only on the
            # FK's ondelete='CASCADE' (presumably for databases that do not
            # enforce it -- TODO confirm).
            self._delete_assigned_protocols(session, idp_id)
            idp_ref = self._get_idp(session, idp_id)
            session.delete(idp_ref)

    def _get_idp(self, session, idp_id):
        """Return the IdP row for ``idp_id`` from ``session``.

        :raises keystone.exception.IdentityProviderNotFound: if missing.
        """
        idp_ref = session.get(IdentityProviderModel, idp_id)
        if not idp_ref:
            raise exception.IdentityProviderNotFound(idp_id=idp_id)
        return idp_ref

    def _get_idp_from_remote_id(self, session, remote_id):
        """Return the ``IdPRemoteIdsModel`` row for ``remote_id``.

        :raises keystone.exception.IdentityProviderNotFound: if no row
            matches; the remote ID is reported in the ``idp_id`` field.
        """
        q = session.query(IdPRemoteIdsModel)
        q = q.filter_by(remote_id=remote_id)
        try:
            return q.one()
        except sql.NotFound:
            raise exception.IdentityProviderNotFound(idp_id=remote_id)

    def list_idps(self, hints=None):
        """List identity providers.

        :param hints: passed to ``sql.filter_limit_query`` to filter and
            limit the query.
        :returns: a list of IdP dictionaries.
        """
        with sql.session_for_read() as session:
            query = session.query(IdentityProviderModel)
            idps = sql.filter_limit_query(IdentityProviderModel, query, hints)
            idps_list = [idp.to_dict() for idp in idps]
            return idps_list

    def get_idp(self, idp_id):
        """Return one identity provider as a dictionary.

        :raises keystone.exception.IdentityProviderNotFound: if missing.
        """
        with sql.session_for_read() as session:
            idp_ref = self._get_idp(session, idp_id)
            return idp_ref.to_dict()

    def get_idp_from_remote_id(self, remote_id):
        """Look up which identity provider owns ``remote_id``.

        :returns: the remote-ID row as ``{'idp_id': ..., 'remote_id': ...}``,
            not the full IdP dictionary.
        :raises keystone.exception.IdentityProviderNotFound: if no IdP owns
            ``remote_id``.
        """
        with sql.session_for_read() as session:
            ref = self._get_idp_from_remote_id(session, remote_id)
            return ref.to_dict()

    def update_idp(self, idp_id, idp):
        """Update the mutable attributes of an identity provider.

        The stored IdP is merged with ``idp`` and rebuilt via ``from_dict``
        (so remote ID strings become child rows); only
        ``IdentityProviderModel.mutable_attributes`` are copied back.

        :returns: the updated IdP as a dictionary.
        :raises keystone.exception.IdentityProviderNotFound: if missing.
        :raises keystone.exception.Conflict: on a duplicate entry.
        """
        try:
            with sql.session_for_write() as session:
                idp_ref = self._get_idp(session, idp_id)
                old_idp = idp_ref.to_dict()
                old_idp.update(idp)
                new_idp = IdentityProviderModel.from_dict(old_idp)
                for attr in IdentityProviderModel.mutable_attributes:
                    setattr(idp_ref, attr, getattr(new_idp, attr))
                return idp_ref.to_dict()
        except sql.DBDuplicateEntry as e:
            self._handle_idp_conflict(e)

    # Protocol CRUD

    def _get_protocol(self, session, idp_id, protocol_id):
        """Return protocol ``protocol_id`` of IdP ``idp_id`` from ``session``.

        :raises keystone.exception.FederatedProtocolNotFound: if missing.
        """
        q = session.query(FederationProtocolModel)
        q = q.filter_by(id=protocol_id, idp_id=idp_id)
        try:
            return q.one()
        except sql.NotFound:
            kwargs = {'protocol_id': protocol_id, 'idp_id': idp_id}
            raise exception.FederatedProtocolNotFound(**kwargs)

    @sql.handle_conflicts(conflict_type='federation_protocol')
    def create_protocol(self, idp_id, protocol_id, protocol):
        """Create a protocol for an existing identity provider.

        ``protocol['id']`` and ``protocol['idp_id']`` are set in place on
        the caller's dict.

        :returns: the created protocol as a dictionary.
        :raises keystone.exception.IdentityProviderNotFound: if the IdP does
            not exist.
        """
        protocol['id'] = protocol_id
        protocol['idp_id'] = idp_id
        with sql.session_for_write() as session:
            # Existence check only; raises if the IdP is missing.
            self._get_idp(session, idp_id)
            protocol_ref = FederationProtocolModel.from_dict(protocol)
            session.add(protocol_ref)
            return protocol_ref.to_dict()

    def update_protocol(self, idp_id, protocol_id, protocol):
        """Update the mutable attributes of a protocol.

        :returns: the updated protocol as a dictionary.
        :raises keystone.exception.FederatedProtocolNotFound: if missing.
        """
        with sql.session_for_write() as session:
            proto_ref = self._get_protocol(session, idp_id, protocol_id)
            old_proto = proto_ref.to_dict()
            old_proto.update(protocol)
            new_proto = FederationProtocolModel.from_dict(old_proto)
            for attr in FederationProtocolModel.mutable_attributes:
                setattr(proto_ref, attr, getattr(new_proto, attr))
            return proto_ref.to_dict()

    def get_protocol(self, idp_id, protocol_id):
        """Return one protocol as a dictionary.

        :raises keystone.exception.FederatedProtocolNotFound: if missing.
        """
        with sql.session_for_read() as session:
            protocol_ref = self._get_protocol(session, idp_id, protocol_id)
            return protocol_ref.to_dict()

    def list_protocols(self, idp_id):
        """Return all protocols of ``idp_id`` as a list of dictionaries."""
        with sql.session_for_read() as session:
            q = session.query(FederationProtocolModel)
            q = q.filter_by(idp_id=idp_id)
            protocols = [protocol.to_dict() for protocol in q]
            return protocols

    def delete_protocol(self, idp_id, protocol_id):
        """Delete one protocol.

        :raises keystone.exception.FederatedProtocolNotFound: if missing.
        """
        with sql.session_for_write() as session:
            key_ref = self._get_protocol(session, idp_id, protocol_id)
            session.delete(key_ref)

    def _delete_assigned_protocols(self, session, idp_id):
        """Bulk-delete every protocol of ``idp_id`` within ``session``."""
        query = session.query(FederationProtocolModel)
        query = query.filter_by(idp_id=idp_id)
        query.delete()

    # Mapping CRUD

    def _get_mapping(self, session, mapping_id):
        """Return the mapping row for ``mapping_id`` from ``session``.

        :raises keystone.exception.MappingNotFound: if missing.
        """
        mapping_ref = session.get(MappingModel, mapping_id)
        if not mapping_ref:
            raise exception.MappingNotFound(mapping_id=mapping_id)
        return mapping_ref

    @sql.handle_conflicts(conflict_type='mapping')
    def create_mapping(self, mapping_id, mapping):
        """Create a mapping from ``mapping['rules']``.

        ``mapping['schema_version']`` is optional; when absent, ``None`` is
        passed to the model.

        :returns: the created mapping as a dictionary.
        """
        ref = {}
        ref['id'] = mapping_id
        ref['rules'] = mapping.get('rules')
        ref['schema_version'] = mapping.get('schema_version')
        with sql.session_for_write() as session:
            mapping_ref = MappingModel.from_dict(ref)
            session.add(mapping_ref)
            return mapping_ref.to_dict()

    def delete_mapping(self, mapping_id):
        """Delete a mapping.

        :raises keystone.exception.MappingNotFound: if missing.
        """
        with sql.session_for_write() as session:
            mapping_ref = self._get_mapping(session, mapping_id)
            session.delete(mapping_ref)

    def list_mappings(self):
        """Return all mappings as a list of dictionaries."""
        with sql.session_for_read() as session:
            mappings = session.query(MappingModel)
            return [x.to_dict() for x in mappings]

    def get_mapping(self, mapping_id):
        """Return one mapping as a dictionary.

        :raises keystone.exception.MappingNotFound: if missing.
        """
        with sql.session_for_read() as session:
            mapping_ref = self._get_mapping(session, mapping_id)
            return mapping_ref.to_dict()

    @sql.handle_conflicts(conflict_type='mapping')
    def update_mapping(self, mapping_id, mapping):
        """Replace a mapping's rules and, if given, its schema version.

        :returns: the updated mapping as a dictionary.
        :raises keystone.exception.MappingNotFound: if missing.
        """
        ref = {}
        ref['id'] = mapping_id
        ref['rules'] = mapping.get('rules')
        # A falsy/missing schema_version keeps the stored value.
        if mapping.get('schema_version'):
            ref['schema_version'] = mapping.get('schema_version')
        with sql.session_for_write() as session:
            mapping_ref = self._get_mapping(session, mapping_id)
            old_mapping = mapping_ref.to_dict()
            old_mapping.update(ref)
            new_mapping = MappingModel.from_dict(old_mapping)
            # MappingModel defines no mutable_attributes, so every attribute
            # (including the unchanged id) is copied back.
            for attr in MappingModel.attributes:
                setattr(mapping_ref, attr, getattr(new_mapping, attr))
            return mapping_ref.to_dict()

    def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
        """Return the mapping referenced by an IdP's protocol.

        :raises keystone.exception.FederatedProtocolNotFound: if the
            protocol is missing.
        :raises keystone.exception.MappingNotFound: if the referenced
            mapping is missing.
        """
        with sql.session_for_read() as session:
            protocol_ref = self._get_protocol(session, idp_id, protocol_id)
            mapping_id = protocol_ref.mapping_id
            mapping_ref = self._get_mapping(session, mapping_id)
            return mapping_ref.to_dict()

    # Service Provider CRUD

    @sql.handle_conflicts(conflict_type='service_provider')
    def create_sp(self, sp_id, sp):
        """Create a service provider.

        ``sp['id']`` is set to ``sp_id`` in place on the caller's dict.

        :returns: the created service provider as a dictionary.
        """
        sp['id'] = sp_id
        with sql.session_for_write() as session:
            sp_ref = ServiceProviderModel.from_dict(sp)
            session.add(sp_ref)
            return sp_ref.to_dict()

    def delete_sp(self, sp_id):
        """Delete a service provider.

        :raises keystone.exception.ServiceProviderNotFound: if missing.
        """
        with sql.session_for_write() as session:
            sp_ref = self._get_sp(session, sp_id)
            session.delete(sp_ref)

    def _get_sp(self, session, sp_id):
        """Return the service provider row for ``sp_id`` from ``session``.

        :raises keystone.exception.ServiceProviderNotFound: if missing.
        """
        sp_ref = session.get(ServiceProviderModel, sp_id)
        if not sp_ref:
            raise exception.ServiceProviderNotFound(sp_id=sp_id)
        return sp_ref

    def list_sps(self, hints=None):
        """List service providers.

        :param hints: passed to ``sql.filter_limit_query`` to filter and
            limit the query.
        :returns: a list of service provider dictionaries.
        """
        with sql.session_for_read() as session:
            query = session.query(ServiceProviderModel)
            sps = sql.filter_limit_query(ServiceProviderModel, query, hints)
            sps_list = [sp.to_dict() for sp in sps]
            return sps_list

    def get_sp(self, sp_id):
        """Return one service provider as a dictionary.

        :raises keystone.exception.ServiceProviderNotFound: if missing.
        """
        with sql.session_for_read() as session:
            sp_ref = self._get_sp(session, sp_id)
            return sp_ref.to_dict()

    def update_sp(self, sp_id, sp):
        """Update the mutable attributes of a service provider.

        :returns: the updated service provider as a dictionary.
        :raises keystone.exception.ServiceProviderNotFound: if missing.
        """
        with sql.session_for_write() as session:
            sp_ref = self._get_sp(session, sp_id)
            old_sp = sp_ref.to_dict()
            old_sp.update(sp)
            new_sp = ServiceProviderModel.from_dict(old_sp)
            for attr in ServiceProviderModel.mutable_attributes:
                setattr(sp_ref, attr, getattr(new_sp, attr))
            return sp_ref.to_dict()

    def get_enabled_service_providers(self):
        """Return the enabled service providers.

        :returns: a list of ``ServiceProviderModel`` instances (not
            dictionaries) whose ``enabled`` column is true.
        """
        with sql.session_for_read() as session:
            service_providers = session.query(ServiceProviderModel)
            service_providers = service_providers.filter_by(enabled=True)
            # Execute the query while the read session is still active.
            # Returning the lazy Query would postpone execution until the
            # caller iterates it, after this session context has exited.
            return service_providers.all()