Changeset - 9753e0907827
[Not reviewed]
beta
Marcin Kuzminski - 2010-12-11 01:54:12
marcin@python-works.com
added dbmigrate package, added model changes
moved out upgrade db command to that package
57 files changed with 5091 insertions and 30 deletions:
0 comments (0 inline, 0 general)
rhodecode/lib/dbmigrate/__init__.py
 
new file 100644
 
# -*- coding: utf-8 -*-
 
"""
 
    rhodecode.lib.dbmigrate.__init__
 
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
    
 
    Database migration modules
 
    
 
    :created_on: Dec 11, 2010
 
    :author: marcink
 
    :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>    
 
    :license: GPLv3, see COPYING for more details.
 
"""
 
# This program is free software; you can redistribute it and/or
 
# modify it under the terms of the GNU General Public License
 
# as published by the Free Software Foundation; version 2
 
# of the License or (at your option) any later version of the license.
 
# 
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
# 
 
# You should have received a copy of the GNU General Public License
 
# along with this program; if not, write to the Free Software
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 
# MA  02110-1301, USA.
 

	
 
from rhodecode.lib.utils import BasePasterCommand, Command, add_cache
 

	
 
from sqlalchemy import engine_from_config
 

	
 
class UpgradeDb(BasePasterCommand):
 
    """Command used for paster to upgrade our database to newer version
 
    """
 

	
 
    max_args = 1
 
    min_args = 1
 

	
 
    usage = "CONFIG_FILE"
 
    summary = "Upgrades current db to newer version given configuration file"
 
    group_name = "RhodeCode"
 

	
 
    parser = Command.standard_parser(verbose=True)
 

	
 
    def command(self):
 
        from pylons import config
 
        add_cache(config)
 
        engine = engine_from_config(config, 'sqlalchemy.db1.')
 
        print engine
 
        raise NotImplementedError('Not implemented yet')
 

	
 

	
 
    def update_parser(self):
 
        self.parser.add_option('--sql',
 
                      action='store_true',
 
                      dest='just_sql',
 
                      help="Prints upgrade sql for further investigation",
 
                      default=False)
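
For orientation, paster discovers a command class like UpgradeDb through a `paste.global_paster_command` entry point declared in setup.py and runs it against an .ini file. A minimal sketch of that wiring, not part of this changeset; the command name `upgrade-db` and the config file names are illustrative assumptions:

from setuptools import setup

setup(
    name='RhodeCode',
    # expose the paster command; 'upgrade-db' is an assumed name
    entry_points="""
    [paste.global_paster_command]
    upgrade-db = rhodecode.lib.dbmigrate:UpgradeDb
    """,
)

# Shell usage would then look roughly like:
#   paster upgrade-db development.ini         # upgrade the configured db
#   paster upgrade-db development.ini --sql   # only print the upgrade SQL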
rhodecode/lib/dbmigrate/migrate/__init__.py
 
new file 100644
 
"""
 
   SQLAlchemy migrate provides two APIs: :mod:`migrate.versioning` for

   database schema versioning and repository management, and

   :mod:`migrate.changeset`, which allows database schema changes to be

   defined in Python.
 
"""
 

	
 
from migrate.versioning import *
 
from migrate.changeset import *
rhodecode/lib/dbmigrate/migrate/changeset/__init__.py
 
new file 100644
 
"""
 
   This module extends SQLAlchemy and provides additional DDL [#]_
 
   support.
 

	
 
   .. [#] SQL Data Definition Language
 
"""
 
import re
 
import warnings
 

	
 
import sqlalchemy
 
from sqlalchemy import __version__ as _sa_version
 

	
 
warnings.simplefilter('always', DeprecationWarning)
 

	
 
_sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
 
SQLA_06 = _sa_version >= (0, 6)
 

	
 
del re
 
del _sa_version
 

	
 
from migrate.changeset.schema import *
 
from migrate.changeset.constraint import *
 

	
 
sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
 
sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
 
sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
 

	
 
sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, )
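
A minimal sketch of what the monkey-patching above enables: once this module is imported, plain SQLAlchemy Table/Column/Index objects grow the changeset methods (create, drop, rename, alter). The engine URL and schema below are illustrative, assuming the standard sqlalchemy-migrate API:

from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String
from migrate.changeset import *  # patches Table, Column, Index, DefaultClause

engine = create_engine('sqlite:///:memory:')
meta = MetaData(bind=engine)
users = Table('users', meta, Column('id', Integer, primary_key=True))
users.create()

# ChangesetColumn.create() -> ALTER TABLE users ADD COLUMN fullname VARCHAR(100)
fullname = Column('fullname', String(100))
fullname.create(users)

# ChangesetTable.rename() -> ALTER TABLE users RENAME TO members
users.rename('members')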
rhodecode/lib/dbmigrate/migrate/changeset/ansisql.py
 
new file 100644
 
"""
 
   Extensions to SQLAlchemy for altering existing tables.
 

	
 
   At the moment, this isn't based on ANSI so much as on

   things that just happen to work with multiple databases.
 
"""
 
import StringIO
 

	
 
import sqlalchemy as sa
 
from sqlalchemy.schema import SchemaVisitor
 
from sqlalchemy.engine.default import DefaultDialect
 
from sqlalchemy.sql import ClauseElement
 
from sqlalchemy.schema import (ForeignKeyConstraint,
 
                               PrimaryKeyConstraint,
 
                               CheckConstraint,
 
                               UniqueConstraint,
 
                               Index)
 

	
 
from migrate import exceptions
 
from migrate.changeset import constraint, SQLA_06
 

	
 
if not SQLA_06:
 
    from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
 
else:
 
    from sqlalchemy.schema import AddConstraint, DropConstraint
 
    from sqlalchemy.sql.compiler import DDLCompiler
 
    SchemaGenerator = SchemaDropper = DDLCompiler
 

	
 

	
 
class AlterTableVisitor(SchemaVisitor):
 
    """Common operations for ``ALTER TABLE`` statements."""
 

	
 
    if SQLA_06:
 
        # engine.Compiler looks for .statement
 
        # when it spawns off a new compiler
 
        statement = ClauseElement()
 

	
 
    def append(self, s):
 
        """Append content to the SchemaIterator's query buffer."""
 

	
 
        self.buffer.write(s)
 

	
 
    def execute(self):
 
        """Execute the contents of the SchemaIterator's buffer."""
 
        try:
 
            return self.connection.execute(self.buffer.getvalue())
 
        finally:
 
            self.buffer.truncate(0)
 

	
 
    def __init__(self, dialect, connection, **kw):
 
        self.connection = connection
 
        self.buffer = StringIO.StringIO()
 
        self.preparer = dialect.identifier_preparer
 
        self.dialect = dialect
 

	
 
    def traverse_single(self, elem):
 
        ret = super(AlterTableVisitor, self).traverse_single(elem)
 
        if ret:
 
            # adapt to 0.6 which uses a string-returning
 
            # object
 
            self.append(" %s" % ret)
 
            
 
    def _to_table(self, param):
 
        """Returns the table object for the given param object."""
 
        if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
 
            ret = param.table
 
        else:
 
            ret = param
 
        return ret
 

	
 
    def start_alter_table(self, param):
 
        """Returns the start of an ``ALTER TABLE`` SQL-Statement.
 

	
 
        Use the param object to determine the table name and use it
 
        for building the SQL statement.
 

	
 
        :param param: object to determine the table from
 
        :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
 
          :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
 
          or string (table name)
 
        """
 
        table = self._to_table(param)
 
        self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
 
        return table
 

	
 

	
 
class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
 
    """Extends ansisql generator for column creation (alter table add col)"""
 

	
 
    def visit_column(self, column):
 
        """Create a column (table already exists).
 

	
 
        :param column: column object
 
        :type column: :class:`sqlalchemy.Column` instance
 
        """
 
        if column.default is not None:
 
            self.traverse_single(column.default)
 

	
 
        table = self.start_alter_table(column)
 
        self.append("ADD ")
 
        self.append(self.get_column_specification(column))
 

	
 
        for cons in column.constraints:
 
            self.traverse_single(cons)
 
        self.execute()
 

	
 
        # ALTER TABLE STATEMENTS
 

	
 
        # add indexes and unique constraints
 
        if column.index_name:
 
            Index(column.index_name,column).create()
 
        elif column.unique_name:
 
            constraint.UniqueConstraint(column,
 
                                        name=column.unique_name).create()
 

	
 
        # SA binds FK constraints to the table, add them manually
 
        for fk in column.foreign_keys:
 
            self.add_foreignkey(fk.constraint)
 

	
 
        # add primary key constraint if needed
 
        if column.primary_key_name:
 
            cons = constraint.PrimaryKeyConstraint(column,
 
                                                   name=column.primary_key_name)
 
            cons.create()
 

	
 
    if SQLA_06:
 
        def add_foreignkey(self, fk):
 
            self.connection.execute(AddConstraint(fk))
 

	
 
class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
 
    """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
 
    DROP COLUMN``).
 
    """
 

	
 
    def visit_column(self, column):
 
        """Drop a column from its table.
 

	
 
        :param column: the column object
 
        :type column: :class:`sqlalchemy.Column`
 
        """
 
        table = self.start_alter_table(column)
 
        self.append('DROP COLUMN %s' % self.preparer.format_column(column))
 
        self.execute()
 

	
 

	
 
class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
 
    """Manages changes to existing schema elements.
 

	
 
    Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
 
    is in SchemaGenerator.
 

	
 
    All items may be renamed. Columns can also have many of their properties -
 
    type, for example - changed.
 

	
 
    Each function is passed a tuple containing (object, name), where

    object is the type of object you'd expect for that function

    (i.e. a table for visit_table) and name is the object's new

    name. ``None`` means the name is unchanged.
 
    """
 

	
 
    def visit_table(self, table):
 
        """Rename a table. Other ops aren't supported."""
 
        self.start_alter_table(table)
 
        self.append("RENAME TO %s" % self.preparer.quote(table.new_name,
 
                                                         table.quote))
 
        self.execute()
 

	
 
    def visit_index(self, index):
 
        """Rename an index"""
 
        if hasattr(self, '_validate_identifier'):
 
            # SA <= 0.6.3
 
            self.append("ALTER INDEX %s RENAME TO %s" % (
 
                    self.preparer.quote(
 
                        self._validate_identifier(
 
                            index.name, True), index.quote),
 
                    self.preparer.quote(
 
                        self._validate_identifier(
 
                            index.new_name, True), index.quote)))
 
        else:
 
            # SA >= 0.6.5
 
            self.append("ALTER INDEX %s RENAME TO %s" % (
 
                    self.preparer.quote(
 
                        self._index_identifier(
 
                            index.name), index.quote),
 
                    self.preparer.quote(
 
                        self._index_identifier(
 
                            index.new_name), index.quote)))
 
        self.execute()
 

	
 
    def visit_column(self, delta):
 
        """Rename/change a column."""
 
        # ALTER COLUMN is implemented as several ALTER statements
 
        keys = delta.keys()
 
        if 'type' in keys:
 
            self._run_subvisit(delta, self._visit_column_type)
 
        if 'nullable' in keys:
 
            self._run_subvisit(delta, self._visit_column_nullable)
 
        if 'server_default' in keys:
 
            # Skip 'default': only handle server-side defaults, others
 
            # are managed by the app, not the db.
 
            self._run_subvisit(delta, self._visit_column_default)
 
        if 'name' in keys:
 
            self._run_subvisit(delta, self._visit_column_name, start_alter=False)
 

	
 
    def _run_subvisit(self, delta, func, start_alter=True):
 
        """Runs visit method based on what needs to be changed on column"""
 
        table = self._to_table(delta.table)
 
        col_name = delta.current_name
 
        if start_alter:
 
            self.start_alter_column(table, col_name)
 
        ret = func(table, delta.result_column, delta)
 
        self.execute()
 

	
 
    def start_alter_column(self, table, col_name):
 
        """Starts ALTER COLUMN"""
 
        self.start_alter_table(table)
 
        self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote))
 

	
 
    def _visit_column_nullable(self, table, column, delta):
 
        nullable = delta['nullable']
 
        if nullable:
 
            self.append("DROP NOT NULL")
 
        else:
 
            self.append("SET NOT NULL")
 

	
 
    def _visit_column_default(self, table, column, delta):
 
        default_text = self.get_column_default_string(column)
 
        if default_text is not None:
 
            self.append("SET DEFAULT %s" % default_text)
 
        else:
 
            self.append("DROP DEFAULT")
 

	
 
    def _visit_column_type(self, table, column, delta):
 
        type_ = delta['type']
 
        if SQLA_06:
 
            type_text = str(type_.compile(dialect=self.dialect))
 
        else:
 
            type_text = type_.dialect_impl(self.dialect).get_col_spec()
 
        self.append("TYPE %s" % type_text)
 

	
 
    def _visit_column_name(self, table, column, delta):
 
        self.start_alter_table(table)
 
        col_name = self.preparer.quote(delta.current_name, table.quote)
 
        new_name = self.preparer.format_column(delta.result_column)
 
        self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
 

	
 

	
 
class ANSIConstraintCommon(AlterTableVisitor):
 
    """
 
    Migrate's constraints require a separate creation function from
 
    SA's: Migrate's constraints are created independently of a table;
 
    SA's are created at the same time as the table.
 
    """
 

	
 
    def get_constraint_name(self, cons):
 
        """Gets a name for the given constraint.
 

	
 
        If the name is already set it will be used otherwise the
 
        constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
 
        method is used.
 

	
 
        :param cons: constraint object
 
        """
 
        if cons.name is not None:
 
            ret = cons.name
 
        else:
 
            ret = cons.name = cons.autoname()
 
        return self.preparer.quote(ret, cons.quote)
 

	
 
    def visit_migrate_primary_key_constraint(self, *p, **k):
 
        self._visit_constraint(*p, **k)
 

	
 
    def visit_migrate_foreign_key_constraint(self, *p, **k):
 
        self._visit_constraint(*p, **k)
 

	
 
    def visit_migrate_check_constraint(self, *p, **k):
 
        self._visit_constraint(*p, **k)
 

	
 
    def visit_migrate_unique_constraint(self, *p, **k):
 
        self._visit_constraint(*p, **k)
 

	
 
if SQLA_06:
 
    class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
 
        def _visit_constraint(self, constraint):
 
            constraint.name = self.get_constraint_name(constraint)
 
            self.append(self.process(AddConstraint(constraint)))
 
            self.execute()
 

	
 
    class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
 
        def _visit_constraint(self, constraint):
 
            constraint.name = self.get_constraint_name(constraint)
 
            self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
 
            self.execute()
 

	
 
else:
 
    class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
 

	
 
        def get_constraint_specification(self, cons, **kwargs):
 
            """Constaint SQL generators.
 
        
 
            We cannot use SA visitors because they append comma.
 
            """
 
        
 
            if isinstance(cons, PrimaryKeyConstraint):
 
                if cons.name is not None:
 
                    self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
 
                self.append("PRIMARY KEY ")
 
                self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
 
                                               for c in cons))
 
                self.define_constraint_deferrability(cons)
 
            elif isinstance(cons, ForeignKeyConstraint):
 
                self.define_foreign_key(cons)
 
            elif isinstance(cons, CheckConstraint):
 
                if cons.name is not None:
 
                    self.append("CONSTRAINT %s " %
 
                                self.preparer.format_constraint(cons))
 
                self.append("CHECK (%s)" % cons.sqltext)
 
                self.define_constraint_deferrability(cons)
 
            elif isinstance(cons, UniqueConstraint):
 
                if cons.name is not None:
 
                    self.append("CONSTRAINT %s " %
 
                                self.preparer.format_constraint(cons))
 
                self.append("UNIQUE (%s)" % \
 
                    (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
 
                self.define_constraint_deferrability(cons)
 
            else:
 
                raise exceptions.InvalidConstraintError(cons)
 

	
 
        def _visit_constraint(self, constraint):
 
        
 
            table = self.start_alter_table(constraint)
 
            constraint.name = self.get_constraint_name(constraint)
 
            self.append("ADD ")
 
            self.get_constraint_specification(constraint)
 
            self.execute()
 
    
 

	
 
    class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
 

	
 
        def _visit_constraint(self, constraint):
 
            self.start_alter_table(constraint)
 
            self.append("DROP CONSTRAINT ")
 
            constraint.name = self.get_constraint_name(constraint)
 
            self.append(self.preparer.format_constraint(constraint))
 
            if constraint.cascade:
 
                self.cascade_constraint(constraint)
 
            self.execute()
 

	
 
        def cascade_constraint(self, constraint):
 
            self.append(" CASCADE")
 

	
 

	
 
class ANSIDialect(DefaultDialect):
 
    columngenerator = ANSIColumnGenerator
 
    columndropper = ANSIColumnDropper
 
    schemachanger = ANSISchemaChanger
 
    constraintgenerator = ANSIConstraintGenerator
 
    constraintdropper = ANSIConstraintDropper
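
To make the division of labour concrete, here is a rough sketch of how these visitors are reached in practice: the high-level helpers in migrate.changeset.schema (further down in this changeset) look up the dialect's schemachanger and drive it. The table and column names are made up, and SQLite actually takes the table-recreation path shown later rather than the generic ANSI statement noted in the comment:

from sqlalchemy import create_engine, MetaData, Table, Column, Integer
from migrate.changeset.schema import alter_column

engine = create_engine('sqlite:///:memory:')
meta = MetaData(bind=engine)
users = Table('users', meta,
              Column('id', Integer, primary_key=True),
              Column('age', Integer))
users.create()

# Generic path: ANSISchemaChanger._visit_column_name would emit
#   ALTER TABLE users RENAME COLUMN age TO years
alter_column('age', name='years', table=users)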
rhodecode/lib/dbmigrate/migrate/changeset/constraint.py
 
new file 100644
 
"""
 
   This module defines standalone schema constraint classes.
 
"""
 
from sqlalchemy import schema
 

	
 
from migrate.exceptions import *
 
from migrate.changeset import SQLA_06
 

	
 
class ConstraintChangeset(object):
 
    """Base class for Constraint classes."""
 

	
 
    def _normalize_columns(self, cols, table_name=False):
 
        """Given: column objects or names; return col names and
 
        (maybe) a table"""
 
        colnames = []
 
        table = None
 
        for col in cols:
 
            if isinstance(col, schema.Column):
 
                if col.table is not None and table is None:
 
                    table = col.table
 
                if table_name:
 
                    col = '.'.join((col.table.name, col.name))
 
                else:
 
                    col = col.name
 
            colnames.append(col)
 
        return colnames, table
 

	
 
    def __do_imports(self, visitor_name, *a, **kw):
 
        engine = kw.pop('engine', self.table.bind)
 
        from migrate.changeset.databases.visitor import (get_engine_visitor,
 
                                                         run_single_visitor)
 
        visitorcallable = get_engine_visitor(engine, visitor_name)
 
        run_single_visitor(engine, visitorcallable, self, *a, **kw)
 

	
 
    def create(self, *a, **kw):
 
        """Create the constraint in the database.
 

	
 
        :param engine: the database engine to use. If this is \
 
        :keyword:`None` the instance's engine will be used
 
        :type engine: :class:`sqlalchemy.engine.base.Engine`
 
        :param connection: reuse an existing connection instead of creating a new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        """
 
        # TODO: set the parent here instead of in __init__
 
        self.__do_imports('constraintgenerator', *a, **kw)
 

	
 
    def drop(self, *a, **kw):
 
        """Drop the constraint from the database.
 

	
 
        :param engine: the database engine to use. If this is
 
          :keyword:`None` the instance's engine will be used
 
        :param cascade: Issue CASCADE drop if database supports it
 
        :type engine: :class:`sqlalchemy.engine.base.Engine`
 
        :type cascade: bool
 
        :param connection: reuse an existing connection instead of creating a new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        :returns: Instance with cleared columns
 
        """
 
        self.cascade = kw.pop('cascade', False)
 
        self.__do_imports('constraintdropper', *a, **kw)
 
        # the spirit of Constraint objects is that they
 
        # are immutable (just like in a DB.  they're only ADDed
 
        # or DROPped).
 
        #self.columns.clear()
 
        return self
 

	
 

	
 
class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
 
    """Construct PrimaryKeyConstraint
 

	
 
    Migrate's additional parameters:
 

	
 
    :param cols: Columns in constraint.
 
    :param table: If columns are passed as strings, this kw is required
 
    :type table: Table instance
 
    :type cols: strings or Column instances
 
    """
 

	
 
    __migrate_visit_name__ = 'migrate_primary_key_constraint'
 

	
 
    def __init__(self, *cols, **kwargs):
 
        colnames, table = self._normalize_columns(cols)
 
        table = kwargs.pop('table', table)
 
        super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
 
        if table is not None:
 
            self._set_parent(table)
 

	
 

	
 
    def autoname(self):
 
        """Mimic the database's automatic constraint names"""
 
        return "%s_pkey" % self.table.name
 

	
 

	
 
class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
 
    """Construct ForeignKeyConstraint
 

	
 
    Migrate's additional parameters:
 

	
 
    :param columns: Columns in constraint
 
    :param refcolumns: Columns that this FK refers to in another table.
 
    :param table: If columns are passed as strings, this kw is required
 
    :type table: Table instance
 
    :type columns: list of strings or Column instances
 
    :type refcolumns: list of strings or Column instances
 
    """
 

	
 
    __migrate_visit_name__ = 'migrate_foreign_key_constraint'
 

	
 
    def __init__(self, columns, refcolumns, *args, **kwargs):
 
        colnames, table = self._normalize_columns(columns)
 
        table = kwargs.pop('table', table)
 
        refcolnames, reftable = self._normalize_columns(refcolumns,
 
                                                        table_name=True)
 
        super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args,
 
                                                   **kwargs)
 
        if table is not None:
 
            self._set_parent(table)
 

	
 
    @property
 
    def referenced(self):
 
        return [e.column for e in self.elements]
 

	
 
    @property
 
    def reftable(self):
 
        return self.referenced[0].table
 

	
 
    def autoname(self):
 
        """Mimic the database's automatic constraint names"""
 
        if hasattr(self.columns, 'keys'):
 
            # SA <= 0.5
 
            firstcol = self.columns[self.columns.keys()[0]]
 
            ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
 
                table=firstcol.table.name,
 
                firstcolumn=firstcol.name,)
 
        else:
 
            # SA >= 0.6
 
            ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
 
                table=self.table.name,
 
                firstcolumn=self.columns[0],)
 
        return ret
 

	
 

	
 
class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
 
    """Construct CheckConstraint
 

	
 
    Migrate's additional parameters:
 

	
 
    :param sqltext: Plain SQL text to check condition
 
    :param columns: If no name is given, you must supply this kw\

    to autoname the constraint
 
    :param table: If columns are passed as strings, this kw is required
 
    :type table: Table instance
 
    :type columns: list of Column instances
 
    :type sqltext: string
 
    """
 

	
 
    __migrate_visit_name__ = 'migrate_check_constraint'
 

	
 
    def __init__(self, sqltext, *args, **kwargs):
 
        cols = kwargs.pop('columns', [])
 
        if not cols and not kwargs.get('name', False):
 
            raise InvalidConstraintError('You must either set the "name" '

                'parameter or "columns" to autogenerate it.')
 
        colnames, table = self._normalize_columns(cols)
 
        table = kwargs.pop('table', table)
 
        schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
 
        if table is not None:
 
            if not SQLA_06:
 
                self.table = table
 
            self._set_parent(table)
 
        self.colnames = colnames
 

	
 
    def autoname(self):
 
        return "%(table)s_%(cols)s_check" % \
 
            dict(table=self.table.name, cols="_".join(self.colnames))
 

	
 

	
 
class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
 
    """Construct UniqueConstraint
 

	
 
    Migrate's additional parameters:
 

	
 
    :param cols: Columns in constraint.
 
    :param table: If columns are passed as strings, this kw is required
 
    :type table: Table instance
 
    :type cols: strings or Column instances
 

	
 
    .. versionadded:: 0.6.0
 
    """
 

	
 
    __migrate_visit_name__ = 'migrate_unique_constraint'
 

	
 
    def __init__(self, *cols, **kwargs):
 
        self.colnames, table = self._normalize_columns(cols)
 
        table = kwargs.pop('table', table)
 
        super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
 
        if table is not None:
 
            self._set_parent(table)
 

	
 
    def autoname(self):
 
        """Mimic the database's automatic constraint names"""
 
        return "%s_%s_key" % (self.table.name, self.colnames[0])
rhodecode/lib/dbmigrate/migrate/changeset/databases/__init__.py
 
new file 100644
 
"""
 
   This module contains database dialect specific changeset
 
   implementations.
 
"""
 
__all__ = [
 
    'postgres',
 
    'sqlite',
 
    'mysql',
 
    'oracle',
 
]
rhodecode/lib/dbmigrate/migrate/changeset/databases/firebird.py
 
new file 100644
 
"""
 
   Firebird database specific implementations of changeset classes.
 
"""
 
from sqlalchemy.databases import firebird as sa_base
 

	
 
from migrate import exceptions
 
from migrate.changeset import ansisql, SQLA_06
 

	
 

	
 
if SQLA_06:
 
    FBSchemaGenerator = sa_base.FBDDLCompiler
 
else:
 
    FBSchemaGenerator = sa_base.FBSchemaGenerator
 

	
 
class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
 
    """Firebird column generator implementation."""
 

	
 

	
 
class FBColumnDropper(ansisql.ANSIColumnDropper):
 
    """Firebird column dropper implementation."""
 

	
 
    def visit_column(self, column):
 
        """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
 

	
 
        Drop primary key and unique constraints if dropped column is referencing it."""
 
        if column.primary_key:
 
            if column.table.primary_key.columns.contains_column(column):
 
                column.table.primary_key.drop()
 
                # TODO: recreate primary key if it references more than this column
 
        if column.unique or getattr(column, 'unique_name', None):
 
            for cons in column.table.constraints:
 
                if cons.contains_column(column):
 
                    cons.drop()
 
                    # TODO: recreate unique constraint if it references more than this column
 

	
 
        table = self.start_alter_table(column)
 
        self.append('DROP %s' % self.preparer.format_column(column))
 
        self.execute()
 

	
 

	
 
class FBSchemaChanger(ansisql.ANSISchemaChanger):
 
    """Firebird schema changer implementation."""
 

	
 
    def visit_table(self, table):
 
        """Rename table not supported"""
 
        raise exceptions.NotSupportedError(
 
            "Firebird does not support renaming tables.")
 

	
 
    def _visit_column_name(self, table, column, delta):
 
        self.start_alter_table(table)
 
        col_name = self.preparer.quote(delta.current_name, table.quote)
 
        new_name = self.preparer.format_column(delta.result_column)
 
        self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
 

	
 
    def _visit_column_nullable(self, table, column, delta):
 
        """Changing NULL is not supported"""
 
        # TODO: http://www.firebirdfaq.org/faq103/
 
        raise exceptions.NotSupportedError(
 
            "Firebird does not support altering NULL bevahior.")
 

	
 

	
 
class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
 
    """Firebird constraint generator implementation."""
 

	
 

	
 
class FBConstraintDropper(ansisql.ANSIConstraintDropper):
 
    """Firebird constaint dropper implementation."""
 

	
 
    def cascade_constraint(self, constraint):
 
        """Cascading constraints is not supported"""
 
        raise exceptions.NotSupportedError(
 
            "Firebird does not support cascading constraints")
 

	
 

	
 
class FBDialect(ansisql.ANSIDialect):
 
    columngenerator = FBColumnGenerator
 
    columndropper = FBColumnDropper
 
    schemachanger = FBSchemaChanger
 
    constraintgenerator = FBConstraintGenerator
 
    constraintdropper = FBConstraintDropper
rhodecode/lib/dbmigrate/migrate/changeset/databases/mysql.py
 
new file 100644
 
"""
 
   MySQL database specific implementations of changeset classes.
 
"""
 

	
 
from sqlalchemy.databases import mysql as sa_base
 
from sqlalchemy import types as sqltypes
 

	
 
from migrate import exceptions
 
from migrate.changeset import ansisql, SQLA_06
 

	
 

	
 
if not SQLA_06:
 
    MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
 
else:
 
    MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
 

	
 
class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
 
    pass
 

	
 

	
 
class MySQLColumnDropper(ansisql.ANSIColumnDropper):
 
    pass
 

	
 

	
 
class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
 

	
 
    def visit_column(self, delta):
 
        table = delta.table
 
        colspec = self.get_column_specification(delta.result_column)
 
        if delta.result_column.autoincrement:
 
            primary_keys = [c for c in table.primary_key.columns
 
                       if (c.autoincrement and
 
                            isinstance(c.type, sqltypes.Integer) and
 
                            not c.foreign_keys)]
 

	
 
            if primary_keys:
 
                first = primary_keys.pop(0)
 
                if first.name == delta.current_name:
 
                    colspec += " AUTO_INCREMENT"
 
        old_col_name = self.preparer.quote(delta.current_name, table.quote)
 

	
 
        self.start_alter_table(table)
 

	
 
        self.append("CHANGE COLUMN %s " % old_col_name)
 
        self.append(colspec)
 
        self.execute()
 

	
 
    def visit_index(self, param):
 
        # If MySQL can do this, I can't find how
 
        raise exceptions.NotSupportedError("MySQL cannot rename indexes")
 

	
 

	
 
class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
 
    pass
 

	
 
if SQLA_06:
 
    class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
 
        def visit_migrate_check_constraint(self, *p, **k):
 
            raise exceptions.NotSupportedError("MySQL does not support CHECK"
 
                " constraints, use triggers instead.")
 

	
 
else:
 
    class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
 

	
 
        def visit_migrate_primary_key_constraint(self, constraint):
 
            self.start_alter_table(constraint)
 
            self.append("DROP PRIMARY KEY")
 
            self.execute()
 

	
 
        def visit_migrate_foreign_key_constraint(self, constraint):
 
            self.start_alter_table(constraint)
 
            self.append("DROP FOREIGN KEY ")
 
            constraint.name = self.get_constraint_name(constraint)
 
            self.append(self.preparer.format_constraint(constraint))
 
            self.execute()
 

	
 
        def visit_migrate_check_constraint(self, *p, **k):
 
            raise exceptions.NotSupportedError("MySQL does not support CHECK"
 
                " constraints, use triggers instead.")
 

	
 
        def visit_migrate_unique_constraint(self, constraint, *p, **k):
 
            self.start_alter_table(constraint)
 
            self.append('DROP INDEX ')
 
            constraint.name = self.get_constraint_name(constraint)
 
            self.append(self.preparer.format_constraint(constraint))
 
            self.execute()
 

	
 

	
 
class MySQLDialect(ansisql.ANSIDialect):
 
    columngenerator = MySQLColumnGenerator
 
    columndropper = MySQLColumnDropper
 
    schemachanger = MySQLSchemaChanger
 
    constraintgenerator = MySQLConstraintGenerator
 
    constraintdropper = MySQLConstraintDropper
rhodecode/lib/dbmigrate/migrate/changeset/databases/oracle.py
 
new file 100644
 
"""
 
   Oracle database specific implementations of changeset classes.
 
"""
 
import sqlalchemy as sa
 
from sqlalchemy.databases import oracle as sa_base
 

	
 
from migrate import exceptions
 
from migrate.changeset import ansisql, SQLA_06
 

	
 

	
 
if not SQLA_06:
 
    OracleSchemaGenerator = sa_base.OracleSchemaGenerator
 
else:
 
    OracleSchemaGenerator = sa_base.OracleDDLCompiler
 

	
 

	
 
class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
 
    pass
 

	
 

	
 
class OracleColumnDropper(ansisql.ANSIColumnDropper):
 
    pass
 

	
 

	
 
class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
 

	
 
    def get_column_specification(self, column, **kwargs):
 
        # Ignore the NOT NULL generated
 
        override_nullable = kwargs.pop('override_nullable', None)
 
        if override_nullable:
 
            orig = column.nullable
 
            column.nullable = True
 
        ret = super(OracleSchemaChanger, self).get_column_specification(
 
            column, **kwargs)
 
        if override_nullable:
 
            column.nullable = orig
 
        return ret
 

	
 
    def visit_column(self, delta):
 
        keys = delta.keys()
 

	
 
        if 'name' in keys:
 
            self._run_subvisit(delta,
 
                               self._visit_column_name,
 
                               start_alter=False)
 

	
 
        if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
 
            self._run_subvisit(delta,
 
                               self._visit_column_change,
 
                               start_alter=False)
 

	
 
    def _visit_column_change(self, table, column, delta):
 
        # Oracle cannot drop a default once created, but it can set it
 
        # to null.  We'll do that if default=None
 
        # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
 
        dropdefault_hack = (column.server_default is None \
 
                                and 'server_default' in delta.keys())
 
        # Oracle apparently doesn't like it when we say "not null" if
 
        # the column's already not null. Fudge it, so we don't need a
 
        # new function
 
        notnull_hack = ((not column.nullable) \
 
                            and ('nullable' not in delta.keys()))
 
        # We need to specify NULL if we're removing a NOT NULL
 
        # constraint
 
        null_hack = (column.nullable and ('nullable' in delta.keys()))
 

	
 
        if dropdefault_hack:
 
            column.server_default = sa.PassiveDefault(sa.sql.null())
 
        if notnull_hack:
 
            column.nullable = True
 
        colspec = self.get_column_specification(column,
 
            override_nullable=null_hack)
 
        if null_hack:
 
            colspec += ' NULL'
 
        if notnull_hack:
 
            column.nullable = False
 
        if dropdefault_hack:
 
            column.server_default = None
 

	
 
        self.start_alter_table(table)
 
        self.append("MODIFY (")
 
        self.append(colspec)
 
        self.append(")")
 

	
 

	
 
class OracleConstraintCommon(object):
 

	
 
    def get_constraint_name(self, cons):
 
        # Oracle constraints can't guess their name like other DBs
 
        if not cons.name:
 
            raise exceptions.NotSupportedError(
 
                "Oracle constraint names must be explicitly stated")
 
        return cons.name
 

	
 

	
 
class OracleConstraintGenerator(OracleConstraintCommon,
 
                                ansisql.ANSIConstraintGenerator):
 
    pass
 

	
 

	
 
class OracleConstraintDropper(OracleConstraintCommon,
 
                              ansisql.ANSIConstraintDropper):
 
    pass
 

	
 

	
 
class OracleDialect(ansisql.ANSIDialect):
 
    columngenerator = OracleColumnGenerator
 
    columndropper = OracleColumnDropper
 
    schemachanger = OracleSchemaChanger
 
    constraintgenerator = OracleConstraintGenerator
 
    constraintdropper = OracleConstraintDropper
rhodecode/lib/dbmigrate/migrate/changeset/databases/postgres.py
 
new file 100644
 
"""
 
   `PostgreSQL`_ database specific implementations of changeset classes.
 

	
 
   .. _`PostgreSQL`: http://www.postgresql.org/
 
"""
 
from migrate.changeset import ansisql, SQLA_06
 

	
 
if not SQLA_06:
 
    from sqlalchemy.databases import postgres as sa_base
 
    PGSchemaGenerator = sa_base.PGSchemaGenerator
 
else:
 
    from sqlalchemy.databases import postgresql as sa_base
 
    PGSchemaGenerator = sa_base.PGDDLCompiler
 

	
 

	
 
class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
 
    """PostgreSQL column generator implementation."""
 
    pass
 

	
 

	
 
class PGColumnDropper(ansisql.ANSIColumnDropper):
 
    """PostgreSQL column dropper implementation."""
 
    pass
 

	
 

	
 
class PGSchemaChanger(ansisql.ANSISchemaChanger):
 
    """PostgreSQL schema changer implementation."""
 
    pass
 

	
 

	
 
class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
 
    """PostgreSQL constraint generator implementation."""
 
    pass
 

	
 

	
 
class PGConstraintDropper(ansisql.ANSIConstraintDropper):
 
    """PostgreSQL constaint dropper implementation."""
 
    pass
 

	
 

	
 
class PGDialect(ansisql.ANSIDialect):
 
    columngenerator = PGColumnGenerator
 
    columndropper = PGColumnDropper
 
    schemachanger = PGSchemaChanger
 
    constraintgenerator = PGConstraintGenerator
 
    constraintdropper = PGConstraintDropper
rhodecode/lib/dbmigrate/migrate/changeset/databases/sqlite.py
 
new file 100644
 
"""
 
   `SQLite`_ database specific implementations of changeset classes.
 

	
 
   .. _`SQLite`: http://www.sqlite.org/
 
"""
 
from UserDict import DictMixin
 
from copy import copy
 

	
 
from sqlalchemy.databases import sqlite as sa_base
 

	
 
from migrate import exceptions
 
from migrate.changeset import ansisql, SQLA_06
 

	
 

	
 
if not SQLA_06:
 
    SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
 
else:
 
    SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
 

	
 
class SQLiteCommon(object):
 

	
 
    def _not_supported(self, op):
 
        raise exceptions.NotSupportedError("SQLite does not support "
 
            "%s; see http://www.sqlite.org/lang_altertable.html" % op)
 

	
 

	
 
class SQLiteHelper(SQLiteCommon):
 

	
 
    def recreate_table(self,table,column=None,delta=None):
 
        table_name = self.preparer.format_table(table)
 

	
 
        # we remove all indexes so as not to have
 
        # problems during copy and re-create
 
        for index in table.indexes:
 
            index.drop()
 

	
 
        self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
 
        self.execute()
 

	
 
        insertion_string = self._modify_table(table, column, delta)
 

	
 
        table.create()
 
        self.append(insertion_string % {'table_name': table_name})
 
        self.execute()
 
        self.append('DROP TABLE migration_tmp')
 
        self.execute()
 
        
 
    def visit_column(self, delta):
 
        if isinstance(delta, DictMixin):
 
            column = delta.result_column
 
            table = self._to_table(delta.table)
 
        else:
 
            column = delta
 
            table = self._to_table(column.table)
 
        self.recreate_table(table,column,delta)
 

	
 
class SQLiteColumnGenerator(SQLiteSchemaGenerator, 
 
                            ansisql.ANSIColumnGenerator,
 
                            # at the end so we get the normal
 
                            # visit_column by default
 
                            SQLiteHelper,
 
                            SQLiteCommon
 
                            ):
 
    """SQLite ColumnGenerator"""
 

	
 
    def _modify_table(self, table, column, delta):
 
        columns = ' ,'.join(map(
 
                self.preparer.format_column,
 
                [c for c in table.columns if c.name!=column.name]))
 
        return ('INSERT INTO %%(table_name)s (%(cols)s) '
 
                'SELECT %(cols)s from migration_tmp')%{'cols':columns}
 

	
 
    def visit_column(self,column):
 
        if column.foreign_keys:
 
            SQLiteHelper.visit_column(self,column)
 
        else:
 
            super(SQLiteColumnGenerator,self).visit_column(column)
 

	
 
class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
 
    """SQLite ColumnDropper"""
 

	
 
    def _modify_table(self, table, column, delta):
 
        columns = ' ,'.join(map(self.preparer.format_column, table.columns))
 
        return 'INSERT INTO %(table_name)s SELECT ' + columns + \
 
            ' from migration_tmp'
 

	
 

	
 
class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
 
    """SQLite SchemaChanger"""
 

	
 
    def _modify_table(self, table, column, delta):
 
        return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
 

	
 
    def visit_index(self, index):
 
        """Does not support ALTER INDEX"""
 
        self._not_supported('ALTER INDEX')
 

	
 

	
 
class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
 

	
 
    def visit_migrate_primary_key_constraint(self, constraint):
 
        tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
 
        cols = ', '.join(map(self.preparer.format_column, constraint.columns))
 
        tname = self.preparer.format_table(constraint.table)
 
        name = self.get_constraint_name(constraint)
 
        msg = tmpl % (name, tname, cols)
 
        self.append(msg)
 
        self.execute()
 

	
 
    def _modify_table(self, table, column, delta):
 
        return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
 

	
 
    def visit_migrate_foreign_key_constraint(self, *p, **k):
 
        self.recreate_table(p[0].table)
 

	
 
    def visit_migrate_unique_constraint(self, *p, **k):
 
        self.recreate_table(p[0].table)
 

	
 

	
 
class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
 
                              SQLiteCommon,
 
                              ansisql.ANSIConstraintCommon):
 

	
 
    def visit_migrate_primary_key_constraint(self, constraint):
 
        tmpl = "DROP INDEX %s "
 
        name = self.get_constraint_name(constraint)
 
        msg = tmpl % (name)
 
        self.append(msg)
 
        self.execute()
 

	
 
    def visit_migrate_foreign_key_constraint(self, *p, **k):
 
        self._not_supported('ALTER TABLE DROP CONSTRAINT')
 

	
 
    def visit_migrate_check_constraint(self, *p, **k):
 
        self._not_supported('ALTER TABLE DROP CONSTRAINT')
 

	
 
    def visit_migrate_unique_constraint(self, *p, **k):
 
        self._not_supported('ALTER TABLE DROP CONSTRAINT')
 

	
 

	
 
# TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
 

	
 
class SQLiteDialect(ansisql.ANSIDialect):
 
    columngenerator = SQLiteColumnGenerator
 
    columndropper = SQLiteColumnDropper
 
    schemachanger = SQLiteSchemaChanger
 
    constraintgenerator = SQLiteConstraintGenerator
 
    constraintdropper = SQLiteConstraintDropper
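
Since SQLite barely supports ALTER TABLE, SQLiteHelper.recreate_table() above rewrites most changes as a rename/copy/drop cycle. A rough sketch of what dropping a column ends up doing; the table is illustrative and the SQL in the comment is approximate:

from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String
import migrate.changeset  # patch in the changeset methods

engine = create_engine('sqlite:///:memory:')
meta = MetaData(bind=engine)
users = Table('users', meta,
              Column('id', Integer, primary_key=True),
              Column('nickname', String(50)))
users.create()

# Roughly equivalent to:
#   ALTER TABLE users RENAME TO migration_tmp
#   CREATE TABLE users (id INTEGER NOT NULL, PRIMARY KEY (id))
#   INSERT INTO users SELECT id FROM migration_tmp
#   DROP TABLE migration_tmp
users.c.nickname.drop()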
rhodecode/lib/dbmigrate/migrate/changeset/databases/visitor.py
 
new file 100644
 
"""
 
   Module for visitor class mapping.
 
"""
 
import sqlalchemy as sa
 

	
 
from migrate.changeset import ansisql
 
from migrate.changeset.databases import (sqlite,
 
                                         postgres,
 
                                         mysql,
 
                                         oracle,
 
                                         firebird)
 

	
 

	
 
# Map SA dialects to the corresponding Migrate extensions
 
DIALECTS = {
 
    "default": ansisql.ANSIDialect,
 
    "sqlite": sqlite.SQLiteDialect,
 
    "postgres": postgres.PGDialect,
 
    "postgresql": postgres.PGDialect,
 
    "mysql": mysql.MySQLDialect,
 
    "oracle": oracle.OracleDialect,
 
    "firebird": firebird.FBDialect,
 
}
 

	
 

	
 
def get_engine_visitor(engine, name):
 
    """
 
    Get the visitor implementation for the given database engine.
 

	
 
    :param engine: SQLAlchemy Engine
 
    :param name: Name of the visitor
 
    :type name: string
 
    :type engine: Engine
 
    :returns: visitor
 
    """
 
    # TODO: link to supported visitors
 
    return get_dialect_visitor(engine.dialect, name)
 

	
 

	
 
def get_dialect_visitor(sa_dialect, name):
 
    """
 
    Get the visitor implementation for the given dialect.
 

	
 
    Finds the visitor implementation based on the dialect class and
 
    returns an instance initialized with the given name.
 

	
 
    Binds dialect specific preparer to visitor.
 
    """
 

	
 
    # map sa dialect to migrate dialect and return visitor
 
    sa_dialect_name = getattr(sa_dialect, 'name', 'default')
 
    migrate_dialect_cls = DIALECTS[sa_dialect_name]
 
    visitor = getattr(migrate_dialect_cls, name)
 

	
 
    # bind preparer
 
    visitor.preparer = sa_dialect.preparer(sa_dialect)
 

	
 
    return visitor
 

	
 
def run_single_visitor(engine, visitorcallable, element,
 
    connection=None, **kwargs):
 
    """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
 
    with support for migrate visitors.
 
    """
 
    if connection is None:
 
        conn = engine.contextual_connect(close_with_result=False)
 
    else:
 
        conn = connection
 
    visitor = visitorcallable(engine.dialect, conn)
 
    try:
 
        if hasattr(element, '__migrate_visit_name__'):
 
            fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
 
        else:
 
            fn = getattr(visitor, 'visit_' + element.__visit_name__)
 
        fn(element, **kwargs)
 
    finally:
 
        if connection is None:
 
            conn.close()
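
A short sketch of how this dispatch is used by the rest of the changeset package (compare ConstraintChangeset.__do_imports above): resolve the dialect-specific visitor by name, then run it against one schema element. The helper function name below is made up for illustration:

from migrate.changeset.databases.visitor import (get_engine_visitor,
                                                 run_single_visitor)

def create_constraint(constraint, engine):
    # 'constraintgenerator' resolves to e.g. PGConstraintGenerator
    # for a PostgreSQL engine (see DIALECTS above)
    visitorcallable = get_engine_visitor(engine, 'constraintgenerator')
    run_single_visitor(engine, visitorcallable, constraint)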
rhodecode/lib/dbmigrate/migrate/changeset/schema.py
 
new file 100644
 
"""
 
   Schema module providing common schema operations.
 
"""
 
import warnings
 

	
 
from UserDict import DictMixin
 

	
 
import sqlalchemy
 

	
 
from sqlalchemy.schema import ForeignKeyConstraint
 
from sqlalchemy.schema import UniqueConstraint
 

	
 
from migrate.exceptions import *
 
from migrate.changeset import SQLA_06
 
from migrate.changeset.databases.visitor import (get_engine_visitor,
 
                                                 run_single_visitor)
 

	
 

	
 
__all__ = [
 
    'create_column',
 
    'drop_column',
 
    'alter_column',
 
    'rename_table',
 
    'rename_index',
 
    'ChangesetTable',
 
    'ChangesetColumn',
 
    'ChangesetIndex',
 
    'ChangesetDefaultClause',
 
    'ColumnDelta',
 
]
 

	
 
DEFAULT_ALTER_METADATA = True
 

	
 

	
 
def create_column(column, table=None, *p, **kw):
 
    """Create a column, given the table.
 
    
 
    API to :meth:`ChangesetColumn.create`.
 
    """
 
    if table is not None:
 
        return table.create_column(column, *p, **kw)
 
    return column.create(*p, **kw)
 

	
 

	
 
def drop_column(column, table=None, *p, **kw):
 
    """Drop a column, given the table.
 
    
 
    API to :meth:`ChangesetColumn.drop`.
 
    """
 
    if table is not None:
 
        return table.drop_column(column, *p, **kw)
 
    return column.drop(*p, **kw)
 

	
 

	
 
def rename_table(table, name, engine=None, **kw):
 
    """Rename a table.
 

	
 
    If Table instance is given, engine is not used.
 

	
 
    API to :meth:`ChangesetTable.rename`.
 

	
 
    :param table: Table to be renamed.
 
    :param name: New name for Table.
 
    :param engine: Engine instance.
 
    :type table: string or Table instance
 
    :type name: string
 
    :type engine: obj
 
    """
 
    table = _to_table(table, engine)
 
    table.rename(name, **kw)
 

	
 

	
 
def rename_index(index, name, table=None, engine=None, **kw):
 
    """Rename an index.
 

	
 
    If Index instance is given,
 
    table and engine are not used.
 

	
 
    API to :meth:`ChangesetIndex.rename`.
 

	
 
    :param index: Index to be renamed.
 
    :param name: New name for index.
 
    :param table: Table to which the Index belongs.
 
    :param engine: Engine instance.
 
    :type index: string or Index instance
 
    :type name: string
 
    :type table: string or Table instance
 
    :type engine: obj
 
    """
 
    index = _to_index(index, table, engine)
 
    index.rename(name, **kw)
 

	
 

	
 
def alter_column(*p, **k):
 
    """Alter a column.
 

	
 
    This is a helper function that creates a :class:`ColumnDelta` and
 
    runs it.
 

	
 
    :argument column:
 
      The name of the column to be altered or a
 
      :class:`ChangesetColumn` column representing it.
 

	
 
    :param table:
 
      A :class:`~sqlalchemy.schema.Table` or table name to
 
      for the table where the column will be changed.
 

	
 
    :param engine:
 
      The :class:`~sqlalchemy.engine.base.Engine` to use for table
 
      reflection and schema alterations.
 
    
 
    :param alter_metadata:
 
      If `True`, which is the default, the
 
      :class:`~sqlalchemy.schema.Column` will also be modified.
 
      If `False`, the :class:`~sqlalchemy.schema.Column` will be left
 
      as it was.
 
    
 
    :returns: A :class:`ColumnDelta` instance representing the change.
 

	
 
    
 
    """
 
    
 
    k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA)
 

	
 
    if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
 
        k['table'] = p[0].table
 
    if 'engine' not in k:
 
        k['engine'] = k['table'].bind
 

	
 
    # deprecation
 
    if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
 
        warnings.warn(
 
            "Passing a Column object to alter_column is deprecated."
 
            " Just pass in keyword parameters instead.",
 
            MigrateDeprecationWarning
 
            )
 
    engine = k['engine']
 
    delta = ColumnDelta(*p, **k)
 

	
 
    visitorcallable = get_engine_visitor(engine, 'schemachanger')
 
    engine._run_visitor(visitorcallable, delta)
 

	
 
    return delta
 

	
 

	
 
def _to_table(table, engine=None):
 
    """Return if instance of Table, else construct new with metadata"""
 
    if isinstance(table, sqlalchemy.Table):
 
        return table
 

	
 
    # Given: table name, maybe an engine
 
    meta = sqlalchemy.MetaData()
 
    if engine is not None:
 
        meta.bind = engine
 
    return sqlalchemy.Table(table, meta)
 

	
 

	
 
def _to_index(index, table=None, engine=None):
 
    """Return if instance of Index, else construct new with metadata"""
 
    if isinstance(index, sqlalchemy.Index):
 
        return index
 

	
 
    # Given: index name; table name required
 
    table = _to_table(table, engine)
 
    ret = sqlalchemy.Index(index)
 
    ret.table = table
 
    return ret
 

	
 

	
 
class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
 
    """Extracts the differences between two columns/column-parameters
 

	
 
        May receive parameters arranged in several different ways:
 

	
 
        * **current_column, new_column, \*p, \*\*kw**
 
            Additional parameters can be specified to override column
 
            differences.
 

	
 
        * **current_column, \*p, \*\*kw**
 
            Additional parameters alter current_column. Table name is extracted
 
            from current_column object.
 
            Name is changed to current_column.name from current_name,
 
            if current_name is specified.
 

	
 
        * **current_col_name, \*p, \*\*kw**
 
            Table kw must be specified.
 

	
 
        :param table: Table to which the current Column should be bound.\
 
        If table name is given, reflection will be used.
 
        :type table: string or Table instance
 
        :param alter_metadata: If True, it will apply changes to metadata.
 
        :type alter_metadata: bool
 
        :param metadata: If `alter_metadata` is true, \
 
        table names are reflected into this metadata
 
        :type metadata: :class:`MetaData` instance
 
        :param engine: When reflecting tables, either engine or metadata must \
 
        be specified to acquire engine object.
 
        :type engine: :class:`Engine` instance
 
        :returns: A :class:`ColumnDelta` instance that provides access to the altered attributes of \

        `result_column` through a :func:`dict`-like object.
 

	
 
        * :class:`ColumnDelta`.result_column is altered column with new attributes
 

	
 
        * :class:`ColumnDelta`.current_name is current name of column in db
 

	
 

	
 
    """
 

	
 
    # Column attributes that can be altered
 
    diff_keys = ('name', 'type', 'primary_key', 'nullable',
 
        'server_onupdate', 'server_default', 'autoincrement')
 
    diffs = dict()
 
    __visit_name__ = 'column'
 

	
 
    def __init__(self, *p, **kw):
 
        self.alter_metadata = kw.pop("alter_metadata", False)
 
        self.meta = kw.pop("metadata", None)
 
        self.engine = kw.pop("engine", None)
 

	
 
        # Things are initialized differently depending on how many column
 
        # parameters are given. Figure out how many and call the appropriate
 
        # method.
 
        if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
 
            # At least one column specified
 
            if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
 
                # Two columns specified
 
                diffs = self.compare_2_columns(*p, **kw)
 
            else:
 
                # Exactly one column specified
 
                diffs = self.compare_1_column(*p, **kw)
 
        else:
 
            # Zero columns specified
 
            if not len(p) or not isinstance(p[0], basestring):
 
                raise ValueError("First argument must be column name")
 
            diffs = self.compare_parameters(*p, **kw)
 

	
 
        self.apply_diffs(diffs)
 

	
 
    def __repr__(self):
 
        return '<ColumnDelta altermetadata=%r, %s>' % (self.alter_metadata,
 
            super(ColumnDelta, self).__repr__())
 

	
 
    def __getitem__(self, key):
 
        if key not in self.keys():
 
            raise KeyError("No such diff key, available: %s" % self.diffs)
 
        return getattr(self.result_column, key)
 

	
 
    def __setitem__(self, key, value):
 
        if key not in self.keys():
 
            raise KeyError("No such diff key, available: %s" % self.diffs)
 
        setattr(self.result_column, key, value)
 

	
 
    def __delitem__(self, key):
 
        raise NotImplementedError
 

	
 
    def keys(self):
 
        return self.diffs.keys()
 

	
 
    def compare_parameters(self, current_name, *p, **k):
 
        """Compares Column objects with reflection"""
 
        self.table = k.pop('table')
 
        self.result_column = self._table.c.get(current_name)
 
        if len(p):
 
            k = self._extract_parameters(p, k, self.result_column)
 
        return k
 

	
 
    def compare_1_column(self, col, *p, **k):
 
        """Compares one Column object"""
 
        self.table = k.pop('table', None)
 
        if self.table is None:
 
            self.table = col.table
 
        self.result_column = col
 
        if len(p):
 
            k = self._extract_parameters(p, k, self.result_column)
 
        return k
 

	
 
    def compare_2_columns(self, old_col, new_col, *p, **k):
 
        """Compares two Column objects"""
 
        self.process_column(new_col)
 
        self.table = k.pop('table', None)
 
        # we cannot use bool() on table in SA06 
 
        if self.table is None:
 
            self.table = old_col.table
 
        if self.table is None:
 
            self.table = new_col.table
 
        self.result_column = old_col
 

	
 
        # set differences
 
        # leave out the type; it is compared separately below
 
        for key in (set(self.diff_keys) - set(('type',))):
 
            val = getattr(new_col, key, None)
 
            if getattr(self.result_column, key, None) != val:
 
                k.setdefault(key, val)
 

	
 
        # inspect types
 
        if not self.are_column_types_eq(self.result_column.type, new_col.type):
 
            k.setdefault('type', new_col.type)
 

	
 
        if len(p):
 
            k = self._extract_parameters(p, k, self.result_column)
 
        return k
 

	
 
    def apply_diffs(self, diffs):
 
        """Populate dict and column object with new values"""
 
        self.diffs = diffs
 
        for key in self.diff_keys:
 
            if key in diffs:
 
                setattr(self.result_column, key, diffs[key])
 

	
 
        self.process_column(self.result_column)
 

	
 
        # create an instance of class type if not yet
 
        if 'type' in diffs and callable(self.result_column.type):
 
            self.result_column.type = self.result_column.type()
 

	
 
        # add column to the table
 
        if self.table is not None and self.alter_metadata:
 
            self.result_column.add_to_table(self.table)
 

	
 
    def are_column_types_eq(self, old_type, new_type):
 
        """Compares two types to be equal"""
 
        ret = old_type.__class__ == new_type.__class__
 

	
 
        # String length is a special case
 
        if ret and isinstance(new_type, sqlalchemy.types.String):
 
            ret = (getattr(old_type, 'length', None) == \
 
                       getattr(new_type, 'length', None))
 
        return ret
 

	
 
    def _extract_parameters(self, p, k, column):
 
        """Extracts data from p and modifies diffs"""
 
        p = list(p)
 
        while len(p):
 
            if isinstance(p[0], basestring):
 
                k.setdefault('name', p.pop(0))
 
            elif isinstance(p[0], sqlalchemy.types.AbstractType):
 
                k.setdefault('type', p.pop(0))
 
            elif callable(p[0]):
 
                p[0] = p[0]()
 
            else:
 
                break
 

	
 
        if len(p):
 
            new_col = column.copy_fixed()
 
            new_col._init_items(*p)
 
            k = self.compare_2_columns(column, new_col, **k)
 
        return k
 

	
 
    def process_column(self, column):
 
        """Processes default values for column"""
 
        # XXX: this is a snippet from SA processing of positional parameters
 
        if not SQLA_06 and column.args:
 
            toinit = list(column.args)
 
        else:
 
            toinit = list()
 

	
 
        if column.server_default is not None:
 
            if isinstance(column.server_default, sqlalchemy.FetchedValue):
 
                toinit.append(column.server_default)
 
            else:
 
                toinit.append(sqlalchemy.DefaultClause(column.server_default))
 
        if column.server_onupdate is not None:
 
            if isinstance(column.server_onupdate, sqlalchemy.FetchedValue):
 
                toinit.append(column.server_onupdate)
 
            else:
 
                toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
 
                                            for_update=True))
 
        if toinit:
 
            column._init_items(*toinit)
 
            
 
        if not SQLA_06:
 
            column.args = []
 

	
 
    def _get_table(self):
 
        return getattr(self, '_table', None)
 

	
 
    def _set_table(self, table):
 
        if isinstance(table, basestring):
 
            if self.alter_metadata:
 
                if not self.meta:
 
                    raise ValueError("metadata must be specified for table"
 
                        " reflection when using alter_metadata")
 
                meta = self.meta
 
                if self.engine:
 
                    meta.bind = self.engine
 
            else:
 
                if not self.engine and not self.meta:
 
                    raise ValueError("engine or metadata must be specified"
 
                        " to reflect tables")
 
                if not self.engine:
 
                    self.engine = self.meta.bind
 
                meta = sqlalchemy.MetaData(bind=self.engine)
 
            self._table = sqlalchemy.Table(table, meta, autoload=True)
 
        elif isinstance(table, sqlalchemy.Table):
 
            self._table = table
 
            if not self.alter_metadata:
 
                self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
 

	
 
    def _get_result_column(self):
 
        return getattr(self, '_result_column', None)
 

	
 
    def _set_result_column(self, column):
 
        """Set Column to Table based on alter_metadata evaluation."""
 
        self.process_column(column)
 
        if not hasattr(self, 'current_name'):
 
            self.current_name = column.name
 
        if self.alter_metadata:
 
            self._result_column = column
 
        else:
 
            self._result_column = column.copy_fixed()
 

	
 
    table = property(_get_table, _set_table)
 
    result_column = property(_get_result_column, _set_result_column)
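
Illustrative sketch (editor's annotation, not part of the vendored file): one way the
dict-like ColumnDelta interface can be consumed. The import path and the table/column
names are made up, and the sketch assumes importing this module installs the Column
extensions (copy_fixed and friends) the way migrate normally does on import.

    import sqlalchemy
    from rhodecode.lib.dbmigrate.migrate.changeset.schema import ColumnDelta

    meta = sqlalchemy.MetaData()
    tbl = sqlalchemy.Table('repositories', meta)
    old = sqlalchemy.Column('description', sqlalchemy.String(100))
    new = sqlalchemy.Column('description', sqlalchemy.String(255), nullable=False)

    delta = ColumnDelta(old, new, table=tbl)
    print delta.keys()           # attributes that differ, e.g. ['type', 'nullable']
    print delta['type']          # dict-like read of an altered attribute
    print delta.result_column    # the old column with the new attributes applied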
 

	
 

	
 
class ChangesetTable(object):
 
    """Changeset extensions to SQLAlchemy tables."""
 

	
 
    def create_column(self, column, *p, **kw):
 
        """Creates a column.
 

	
 
        The column parameter may be a column definition or the name of
 
        a column in this table.
 

	
 
        API to :meth:`ChangesetColumn.create`
 

	
 
        :param column: Column to be created
 
        :type column: Column instance or string
 
        """
 
        if not isinstance(column, sqlalchemy.Column):
 
            # It's a column name
 
            column = getattr(self.c, str(column))
 
        column.create(table=self, *p, **kw)
 

	
 
    def drop_column(self, column, *p, **kw):
 
        """Drop a column, given its name or definition.
 

	
 
        API to :meth:`ChangesetColumn.drop`
 

	
 
        :param column: Column to be dropped
 
        :type column: Column instance or string
 
        """
 
        if not isinstance(column, sqlalchemy.Column):
 
            # It's a column name
 
            try:
 
                column = getattr(self.c, str(column))
 
            except AttributeError:
 
                # That column isn't part of the table. We don't need
 
                # its entire definition to drop the column, just its
 
                # name, so create a dummy column with the same name.
 
                column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
 
        column.drop(table=self, *p, **kw)
 

	
 
    def rename(self, name, connection=None, **kwargs):
 
        """Rename this table.
 

	
 
        :param name: New name of the table.
 
        :type name: string
 
        :param alter_metadata: If True, table will be removed from metadata
 
        :type alter_metadata: bool
 
        :param connection: reuse connection instead of creating a new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        """
 
        self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
 
        engine = self.bind
 
        self.new_name = name
 
        visitorcallable = get_engine_visitor(engine, 'schemachanger')
 
        run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
 

	
 
        # Fix metadata registration
 
        if self.alter_metadata:
 
            self.name = name
 
            self.deregister()
 
            self._set_parent(self.metadata)
 

	
 
    def _meta_key(self):
 
        return sqlalchemy.schema._get_table_key(self.name, self.schema)
 

	
 
    def deregister(self):
 
        """Remove this table from its metadata"""
 
        key = self._meta_key()
 
        meta = self.metadata
 
        if key in meta.tables:
 
            del meta.tables[key]
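
Usage sketch (editor's annotation, not part of the vendored file) for the table-level
helpers above. It assumes that importing the changeset package monkey-patches these
methods onto sqlalchemy.Table, as migrate does at import time, and uses an in-memory
SQLite engine only to keep the sketch self-contained.

    from sqlalchemy import MetaData, Table, Column, Integer, String, create_engine
    # at this changeset the vendored modules still import under the top-level
    # 'migrate' name internally, so that package must also be importable
    import rhodecode.lib.dbmigrate.migrate.changeset  # importing installs the extensions

    engine = create_engine('sqlite://')
    meta = MetaData(bind=engine)
    repos = Table('repos', meta, Column('repo_id', Integer, primary_key=True))
    repos.create()

    repos.create_column(Column('clone_uri', String(255)))  # ALTER TABLE ADD COLUMN
    repos.drop_column('clone_uri')                         # delegates to ChangesetColumn.drop
    repos.rename('repositories')                           # handled by the 'schemachanger' visitor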
 

	
 

	
 
class ChangesetColumn(object):
 
    """Changeset extensions to SQLAlchemy columns."""
 

	
 
    def alter(self, *p, **k):
 
        """Makes a call to :func:`alter_column` for the column this
 
        method is called on. 
 
        """
 
        if 'table' not in k:
 
            k['table'] = self.table
 
        if 'engine' not in k:
 
            k['engine'] = k['table'].bind
 
        return alter_column(self, *p, **k)
 

	
 
    def create(self, table=None, index_name=None, unique_name=None,
 
               primary_key_name=None, populate_default=True, connection=None, **kwargs):
 
        """Create this column in the database.
 

	
 
        Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
 
        for most databases.
 

	
 
        :param table: Table instance to create on.
 
        :param index_name: Creates :class:`ChangesetIndex` on this column.
 
        :param unique_name: Creates :class:\
 
`~migrate.changeset.constraint.UniqueConstraint` on this column.
 
        :param primary_key_name: Creates :class:\
 
`~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
 
        :param alter_metadata: If True, column will be added to table object.
 
        :param populate_default: If True, created column will be \
 
populated with defaults
 
        :param connection: reuse connection instead of creating a new one.
 
        :type table: Table instance
 
        :type index_name: string
 
        :type unique_name: string
 
        :type primary_key_name: string
 
        :type alter_metadata: bool
 
        :type populate_default: bool
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 

	
 
        :returns: self
 
        """
 
        self.populate_default = populate_default
 
        self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
 
        self.index_name = index_name
 
        self.unique_name = unique_name
 
        self.primary_key_name = primary_key_name
 
        for cons in ('index_name', 'unique_name', 'primary_key_name'):
 
            self._check_sanity_constraints(cons)
 

	
 
        if self.alter_metadata:
 
            self.add_to_table(table)
 
        engine = self.table.bind
 
        visitorcallable = get_engine_visitor(engine, 'columngenerator')
 
        engine._run_visitor(visitorcallable, self, connection, **kwargs)
 

	
 
        # TODO: reuse existing connection
 
        if self.populate_default and self.default is not None:
 
            stmt = table.update().values({self: engine._execute_default(self.default)})
 
            engine.execute(stmt)
 

	
 
        return self
 

	
 
    def drop(self, table=None, connection=None, **kwargs):
 
        """Drop this column from the database, leaving its table intact.
 

	
 
        ``ALTER TABLE DROP COLUMN``, for most databases.
 

	
 
        :param alter_metadata: If True, column will be removed from table object.
 
        :type alter_metadata: bool
 
        :param connection: reuse connection instead of creating a new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        """
 
        self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
 
        if table is not None:
 
            self.table = table
 
        engine = self.table.bind
 
        if self.alter_metadata:
 
            self.remove_from_table(self.table, unset_table=False)
 
        visitorcallable = get_engine_visitor(engine, 'columndropper')
 
        engine._run_visitor(visitorcallable, self, connection, **kwargs)
 
        if self.alter_metadata:
 
            self.table = None
 
        return self
 

	
 
    def add_to_table(self, table):
 
        if table is not None and self.table is None:
 
            self._set_parent(table)
 

	
 
    def _col_name_in_constraint(self, cons, name):
 
        return False
 
    
 
    def remove_from_table(self, table, unset_table=True):
 
        # TODO: remove primary keys, constraints, etc
 
        if unset_table:
 
            self.table = None
 
            
 
        to_drop = set()
 
        for index in table.indexes:
 
            columns = []
 
            for col in index.columns:
 
                if col.name != self.name:
 
                    columns.append(col)
 
            if columns:
 
                index.columns = columns
 
            else:
 
                to_drop.add(index)
 
        table.indexes = table.indexes - to_drop
 
        
 
        to_drop = set()
 
        for cons in table.constraints:
 
            # TODO: deal with other types of constraint
 
            if isinstance(cons, (ForeignKeyConstraint,
 
                                UniqueConstraint)):
 
                for col_name in cons.columns:
 
                    if not isinstance(col_name, basestring):
 
                        col_name = col_name.name
 
                    if self.name == col_name:
 
                        to_drop.add(cons)
 
        table.constraints = table.constraints - to_drop
 
        
 
        if table.c.contains_column(self):
 
            table.c.remove(self)
 

	
 
    # TODO: this is fixed in 0.6
 
    def copy_fixed(self, **kw):
 
        """Create a copy of this ``Column``, with all attributes."""
 
        return sqlalchemy.Column(self.name, self.type, self.default,
 
            key=self.key,
 
            primary_key=self.primary_key,
 
            nullable=self.nullable,
 
            quote=self.quote,
 
            index=self.index,
 
            unique=self.unique,
 
            onupdate=self.onupdate,
 
            autoincrement=self.autoincrement,
 
            server_default=self.server_default,
 
            server_onupdate=self.server_onupdate,
 
            *[c.copy(**kw) for c in self.constraints])
 

	
 
    def _check_sanity_constraints(self, name):
 
        """Check if constraints names are correct"""
 
        obj = getattr(self, name)
 
        if (getattr(self, name[:-5]) and not obj):
 
            raise InvalidConstraintError("Column.create() accepts index_name,"
 
            " primary_key_name and unique_name to generate constraints")
 
        if not isinstance(obj, basestring) and obj is not None:
 
            raise InvalidConstraintError(
 
            "%s argument for column must be constraint name" % name)
 

	
 

	
 
class ChangesetIndex(object):
 
    """Changeset extensions to SQLAlchemy Indexes."""
 

	
 
    __visit_name__ = 'index'
 

	
 
    def rename(self, name, connection=None, **kwargs):
 
        """Change the name of an index.
 

	
 
        :param name: New name of the Index.
 
        :type name: string
 
        :param alter_metadata: If True, Index object will be altered.
 
        :type alter_metadata: bool
 
        :param connection: reuse connection instead of creating a new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        """
 
        self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
 
        engine = self.table.bind
 
        self.new_name = name
 
        visitorcallable = get_engine_visitor(engine, 'schemachanger')
 
        engine._run_visitor(visitorcallable, self, connection, **kwargs)
 
        if self.alter_metadata:
 
            self.name = name
 

	
 

	
 
class ChangesetDefaultClause(object):
 
    """Implements comparison between :class:`DefaultClause` instances"""
 

	
 
    def __eq__(self, other):
 
        if isinstance(other, self.__class__):
 
            if self.arg == other.arg:
 
                return True
 

	
 
    def __ne__(self, other):
 
        return not self.__eq__(other)
rhodecode/lib/dbmigrate/migrate/exceptions.py
Show inline comments
 
new file 100644
 
"""
 
   Provide exception classes for :mod:`migrate`
 
"""
 

	
 

	
 
class Error(Exception):
 
    """Error base class."""
 

	
 

	
 
class ApiError(Error):
 
    """Base class for API errors."""
 

	
 

	
 
class KnownError(ApiError):
 
    """A known error condition."""
 

	
 

	
 
class UsageError(ApiError):
 
    """A known error condition where help should be displayed."""
 

	
 

	
 
class ControlledSchemaError(Error):
 
    """Base class for controlled schema errors."""
 

	
 

	
 
class InvalidVersionError(ControlledSchemaError):
 
    """Invalid version number."""
 

	
 

	
 
class DatabaseNotControlledError(ControlledSchemaError):
 
    """Database should be under version control, but it's not."""
 

	
 

	
 
class DatabaseAlreadyControlledError(ControlledSchemaError):
 
    """Database shouldn't be under version control, but it is"""
 

	
 

	
 
class WrongRepositoryError(ControlledSchemaError):
 
    """This database is under version control by another repository."""
 

	
 

	
 
class NoSuchTableError(ControlledSchemaError):
 
    """The table does not exist."""
 

	
 

	
 
class PathError(Error):
 
    """Base class for path errors."""
 

	
 

	
 
class PathNotFoundError(PathError):
 
    """A path with no file was required; found a file."""
 

	
 

	
 
class PathFoundError(PathError):
 
    """A path with a file was required; found no file."""
 

	
 

	
 
class RepositoryError(Error):
 
    """Base class for repository errors."""
 

	
 

	
 
class InvalidRepositoryError(RepositoryError):
 
    """Invalid repository error."""
 

	
 

	
 
class ScriptError(Error):
 
    """Base class for script errors."""
 

	
 

	
 
class InvalidScriptError(ScriptError):
 
    """Invalid script error."""
 

	
 

	
 
# NOTE: this redefinition shadows the ControlledSchemaError subclass of the
# same name defined earlier in this module
class InvalidVersionError(Error):
 
    """Invalid version error."""
 

	
 
# migrate.changeset
 

	
 
class NotSupportedError(Error):
 
    """Not supported error"""
 

	
 

	
 
class InvalidConstraintError(Error):
 
    """Invalid constraint error"""
 

	
 
class MigrateDeprecationWarning(DeprecationWarning):
 
    """Warning for deprecated features in Migrate"""
rhodecode/lib/dbmigrate/migrate/versioning/__init__.py
Show inline comments
 
new file 100644
 
"""
 
   This package provides functionality to create and manage
 
   repositories of database schema changesets and to apply these
 
   changesets to databases.
 
"""
rhodecode/lib/dbmigrate/migrate/versioning/api.py
Show inline comments
 
new file 100644
 
"""
 
   This module provides an external API to the versioning system.
 

	
 
   .. versionchanged:: 0.6.0
 
    :func:`migrate.versioning.api.test` and schema diff functions
 
    changed order of positional arguments so all accept `url` and `repository`
 
    as first arguments.
 

	
 
   .. versionchanged:: 0.5.4
 
    ``--preview_sql`` displays source file when using SQL scripts.
 
    If Python script is used, it runs the action with mocked engine and
 
    returns captured SQL statements.
 

	
 
   .. versionchanged:: 0.5.4
 
    Deprecated ``--echo`` parameter in favour of new
 
    :func:`migrate.versioning.util.construct_engine` behavior.
 
"""
 

	
 
# Dear migrate developers,
 
#
 
# please do not comment this module using sphinx syntax because its
 
# docstrings are presented as user help and most users cannot
 
# interpret sphinx annotated ReStructuredText.
 
#
 
# Thanks,
 
# Jan Dittberner
 

	
 
import sys
 
import inspect
 
import logging
 

	
 
from migrate import exceptions
 
from migrate.versioning import (repository, schema, version,
 
    script as script_) # command name conflict
 
from migrate.versioning.util import catch_known_errors, with_engine
 

	
 

	
 
log = logging.getLogger(__name__)
 
command_desc = {
 
    'help': 'displays help on a given command',
 
    'create': 'create an empty repository at the specified path',
 
    'script': 'create an empty change Python script',
 
    'script_sql': 'create empty change SQL scripts for given database',
 
    'version': 'display the latest version available in a repository',
 
    'db_version': 'show the current version of the repository under version control',
 
    'source': 'display the Python code for a particular version in this repository',
 
    'version_control': 'mark a database as under this repository\'s version control',
 
    'upgrade': 'upgrade a database to a later version',
 
    'downgrade': 'downgrade a database to an earlier version',
 
    'drop_version_control': 'removes version control from a database',
 
    'manage': 'creates a Python script that runs Migrate with a set of default values',
 
    'test': 'performs the upgrade and downgrade command on the given database',
 
    'compare_model_to_db': 'compare MetaData against the current database state',
 
    'create_model': 'dump the current database as a Python model to stdout',
 
    'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
 
    'update_db_from_model': 'modify the database to match the structure of the current MetaData',
 
}
 
__all__ = command_desc.keys()
 

	
 
Repository = repository.Repository
 
ControlledSchema = schema.ControlledSchema
 
VerNum = version.VerNum
 
PythonScript = script_.PythonScript
 
SqlScript = script_.SqlScript
 

	
 

	
 
# deprecated
 
def help(cmd=None, **opts):
 
    """%prog help COMMAND
 

	
 
    Displays help on a given command.
 
    """
 
    if cmd is None:
 
        raise exceptions.UsageError(None)
 
    try:
 
        func = globals()[cmd]
 
    except KeyError:
 
        raise exceptions.UsageError(
 
            "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
 
    ret = func.__doc__
 
    if sys.argv[0]:
 
        ret = ret.replace('%prog', sys.argv[0])
 
    return ret
 

	
 
@catch_known_errors
 
def create(repository, name, **opts):
 
    """%prog create REPOSITORY_PATH NAME [--table=TABLE]
 

	
 
    Create an empty repository at the specified path.
 

	
 
    You can specify the version_table to be used; by default, it is
 
    'migrate_version'.  This table is created in all version-controlled
 
    databases.
 
    """
 
    repo_path = Repository.create(repository, name, **opts)
 

	
 

	
 
@catch_known_errors
 
def script(description, repository, **opts):
 
    """%prog script DESCRIPTION REPOSITORY_PATH
 

	
 
    Create an empty change script using the next unused version number
 
    appended with the given description.
 

	
 
    For instance, manage.py script "Add initial tables" creates:
 
    repository/versions/001_Add_initial_tables.py
 
    """
 
    repo = Repository(repository)
 
    repo.create_script(description, **opts)
 

	
 

	
 
@catch_known_errors
 
def script_sql(database, repository, **opts):
 
    """%prog script_sql DATABASE REPOSITORY_PATH
 

	
 
    Create empty change SQL scripts for given DATABASE, where DATABASE
 
    is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.)
 
    or generic ('default').
 

	
 
    For instance, manage.py script_sql postgres creates:
 
    repository/versions/001_postgres_upgrade.sql and
 
    repository/versions/001_postgres_postgres.sql
 
    """
 
    repo = Repository(repository)
 
    repo.create_script_sql(database, **opts)
 

	
 

	
 
def version(repository, **opts):
 
    """%prog version REPOSITORY_PATH
 

	
 
    Display the latest version available in a repository.
 
    """
 
    repo = Repository(repository)
 
    return repo.latest
 

	
 

	
 
@with_engine
 
def db_version(url, repository, **opts):
 
    """%prog db_version URL REPOSITORY_PATH
 

	
 
    Show the current version of the repository with the given
 
    connection string, under version control of the specified
 
    repository.
 

	
 
    The url should be any valid SQLAlchemy connection string.
 
    """
 
    engine = opts.pop('engine')
 
    schema = ControlledSchema(engine, repository)
 
    return schema.version
 

	
 

	
 
def source(version, dest=None, repository=None, **opts):
 
    """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
 

	
 
    Display the Python code for a particular version in this
 
    repository.  Save it to the file at DESTINATION or, if omitted,
 
    send to stdout.
 
    """
 
    if repository is None:
 
        raise exceptions.UsageError("A repository must be specified")
 
    repo = Repository(repository)
 
    ret = repo.version(version).script().source()
 
    if dest is not None:
 
        dest = open(dest, 'w')
 
        dest.write(ret)
 
        dest.close()
 
        ret = None
 
    return ret
 

	
 

	
 
def upgrade(url, repository, version=None, **opts):
 
    """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
 

	
 
    Upgrade a database to a later version.
 

	
 
    This runs the upgrade() function defined in your change scripts.
 

	
 
    By default, the database is updated to the latest available
 
    version. You may specify a version instead, if you wish.
 

	
 
    You may preview the Python or SQL code to be executed, rather than
 
    actually executing it, using the appropriate 'preview' option.
 
    """
 
    err = "Cannot upgrade a database of version %s to version %s. "\
 
        "Try 'downgrade' instead."
 
    return _migrate(url, repository, version, upgrade=True, err=err, **opts)
 

	
 

	
 
def downgrade(url, repository, version, **opts):
 
    """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
 

	
 
    Downgrade a database to an earlier version.
 

	
 
    This is the reverse of upgrade; this runs the downgrade() function
 
    defined in your change scripts.
 

	
 
    You may preview the Python or SQL code to be executed, rather than
 
    actually executing it, using the appropriate 'preview' option.
 
    """
 
    err = "Cannot downgrade a database of version %s to version %s. "\
 
        "Try 'upgrade' instead."
 
    return _migrate(url, repository, version, upgrade=False, err=err, **opts)
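
The two functions above can be driven programmatically as well as from the shell; a
short sketch (editor's annotation, not part of the vendored file) with a placeholder
URL and repository path:

    from rhodecode.lib.dbmigrate.migrate.versioning import api

    DB_URL = 'sqlite:///rhodecode.db'      # any SQLAlchemy connection string
    REPO = '/path/to/change/repository'    # directory created by api.create()

    api.upgrade(DB_URL, REPO)              # run upgrade() scripts up to the latest version
    api.upgrade(DB_URL, REPO, 3)           # ...or stop at an explicit version
    api.downgrade(DB_URL, REPO, 2)         # reverse direction; the version is mandatory here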
 

	
 
@with_engine
 
def test(url, repository, **opts):
 
    """%prog test URL REPOSITORY_PATH [VERSION]
 

	
 
    Performs the upgrade and downgrade operations on the given
 
    database. This is not a real test and may leave the database in a
 
    bad state. You should therefore run the test on a copy of
 
    your database.
 
    """
 
    engine = opts.pop('engine')
 
    repos = Repository(repository)
 
    script = repos.version(None).script()
 

	
 
    # Upgrade
 
    log.info("Upgrading...")
 
    script.run(engine, 1)
 
    log.info("done")
 

	
 
    log.info("Downgrading...")
 
    script.run(engine, -1)
 
    log.info("done")
 
    log.info("Success")
 

	
 

	
 
@with_engine
 
def version_control(url, repository, version=None, **opts):
 
    """%prog version_control URL REPOSITORY_PATH [VERSION]
 

	
 
    Mark a database as under this repository's version control.
 

	
 
    Once a database is under version control, schema changes should
 
    only be done via change scripts in this repository.
 

	
 
    This creates the table version_table in the database.
 

	
 
    The url should be any valid SQLAlchemy connection string.
 

	
 
    By default, the database begins at version 0 and is assumed to be
 
    empty.  If the database is not empty, you may specify a version at
 
    which to begin instead. No attempt is made to verify this
 
    version's correctness - the database schema is expected to be
 
    identical to what it would be if the database were created from
 
    scratch.
 
    """
 
    engine = opts.pop('engine')
 
    ControlledSchema.create(engine, repository, version)
 

	
 

	
 
@with_engine
 
def drop_version_control(url, repository, **opts):
 
    """%prog drop_version_control URL REPOSITORY_PATH
 

	
 
    Removes version control from a database.
 
    """
 
    engine = opts.pop('engine')
 
    schema = ControlledSchema(engine, repository)
 
    schema.drop()
 

	
 

	
 
def manage(file, **opts):
 
    """%prog manage FILENAME [VARIABLES...]
 

	
 
    Creates a script that runs Migrate with a set of default values.
 

	
 
    For example::
 

	
 
        %prog manage manage.py --repository=/path/to/repository \
 
--url=sqlite:///project.db
 

	
 
    would create the script manage.py. The following two commands
 
    would then have exactly the same results::
 

	
 
        python manage.py version
 
        %prog version --repository=/path/to/repository
 
    """
 
    Repository.create_manage_file(file, **opts)
 

	
 

	
 
@with_engine
 
def compare_model_to_db(url, repository, model, **opts):
 
    """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
 

	
 
    Compare the current model (assumed to be a module level variable
 
    of type sqlalchemy.MetaData) against the current database.
 

	
 
    NOTE: This is EXPERIMENTAL.
 
    """  # TODO: get rid of EXPERIMENTAL label
 
    engine = opts.pop('engine')
 
    return ControlledSchema.compare_model_to_db(engine, model, repository)
 

	
 

	
 
@with_engine
 
def create_model(url, repository, **opts):
 
    """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
 

	
 
    Dump the current database as a Python model to stdout.
 

	
 
    NOTE: This is EXPERIMENTAL.
 
    """  # TODO: get rid of EXPERIMENTAL label
 
    engine = opts.pop('engine')
 
    declarative = opts.get('declarative', False)
 
    return ControlledSchema.create_model(engine, repository, declarative)
 

	
 

	
 
@catch_known_errors
 
@with_engine
 
def make_update_script_for_model(url, repository, oldmodel, model, **opts):
 
    """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
 

	
 
    Create a script changing the old Python model to the new (current)
 
    Python model, sending to stdout.
 

	
 
    NOTE: This is EXPERIMENTAL.
 
    """  # TODO: get rid of EXPERIMENTAL label
 
    engine = opts.pop('engine')
 
    return PythonScript.make_update_script_for_model(
 
        engine, oldmodel, model, repository, **opts)
 

	
 

	
 
@with_engine
 
def update_db_from_model(url, repository, model, **opts):
 
    """%prog update_db_from_model URL REPOSITORY_PATH MODEL
 

	
 
    Modify the database to match the structure of the current Python
 
    model. This also sets the db_version number to the latest in the
 
    repository.
 

	
 
    NOTE: This is EXPERIMENTAL.
 
    """  # TODO: get rid of EXPERIMENTAL label
 
    engine = opts.pop('engine')
 
    schema = ControlledSchema(engine, repository)
 
    schema.update_db_from_model(model)
 

	
 
@with_engine
 
def _migrate(url, repository, version, upgrade, err, **opts):
 
    engine = opts.pop('engine')
 
    url = str(engine.url)
 
    schema = ControlledSchema(engine, repository)
 
    version = _migrate_version(schema, version, upgrade, err)
 

	
 
    changeset = schema.changeset(version)
 
    for ver, change in changeset:
 
        nextver = ver + changeset.step
 
        log.info('%s -> %s... ', ver, nextver)
 

	
 
        if opts.get('preview_sql'):
 
            if isinstance(change, PythonScript):
 
                log.info(change.preview_sql(url, changeset.step, **opts))
 
            elif isinstance(change, SqlScript):
 
                log.info(change.source())
 

	
 
        elif opts.get('preview_py'):
 
            if not isinstance(change, PythonScript):
 
                raise exceptions.UsageError("Python source can be only displayed"
 
                    " for python migration files")
 
            source_ver = max(ver, nextver)
 
            module = schema.repository.version(source_ver).script().module
 
            funcname = upgrade and "upgrade" or "downgrade"
 
            func = getattr(module, funcname)
 
            log.info(inspect.getsource(func))
 
        else:
 
            schema.runchange(ver, change, changeset.step)
 
            log.info('done')
 

	
 

	
 
def _migrate_version(schema, version, upgrade, err):
 
    if version is None:
 
        return version
 
    # Version is specified: ensure we're upgrading in the right direction
 
    # (current version < target version for upgrading; reverse for down)
 
    version = VerNum(version)
 
    cur = schema.version
 
    if upgrade is not None:
 
        if upgrade:
 
            direction = cur <= version
 
        else:
 
            direction = cur >= version
 
        if not direction:
 
            raise exceptions.KnownError(err % (cur, version))
 
    return version
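
A typical end-to-end flow through this module (editor's annotation, not part of the
vendored file); the URL and repository location are placeholders:

    from rhodecode.lib.dbmigrate.migrate.versioning import api

    DB_URL = 'sqlite:///app.db'            # placeholder connection string
    REPO = '/tmp/my_repository'            # placeholder repository location

    api.create(REPO, 'my_project')         # new, empty change-script repository
    api.script('Add user table', REPO)     # becomes versions/001_Add_user_table.py
    api.version_control(DB_URL, REPO)      # creates the migrate_version bookkeeping table
    print api.db_version(DB_URL, REPO)     # 0
    api.upgrade(DB_URL, REPO)              # apply 001 (and anything after it)
    print api.db_version(DB_URL, REPO)     # 1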
rhodecode/lib/dbmigrate/migrate/versioning/cfgparse.py
Show inline comments
 
new file 100644
 
"""
 
   Configuration parser module.
 
"""
 

	
 
from ConfigParser import ConfigParser
 

	
 
from migrate.versioning.config import *
 
from migrate.versioning import pathed
 

	
 

	
 
class Parser(ConfigParser):
 
    """A project configuration file."""
 

	
 
    def to_dict(self, sections=None):
 
        """It's easier to access config values like dictionaries"""
 
        return self._sections
 

	
 

	
 
class Config(pathed.Pathed, Parser):
 
    """Configuration class."""
 

	
 
    def __init__(self, path, *p, **k):
 
        """Confirm the config file exists; read it."""
 
        self.require_found(path)
 
        pathed.Pathed.__init__(self, path)
 
        Parser.__init__(self, *p, **k)
 
        self.read(path)
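
Reading a repository's migrate.cfg through this class (editor's annotation, not part of
the vendored file); the path is a placeholder:

    from rhodecode.lib.dbmigrate.migrate.versioning.cfgparse import Config

    cfg = Config('/tmp/my_repository/migrate.cfg')
    print cfg.get('db_settings', 'version_table')    # e.g. 'migrate_version'
    print cfg.to_dict()['db_settings']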
rhodecode/lib/dbmigrate/migrate/versioning/config.py
Show inline comments
 
new file 100644
 
#!/usr/bin/python
 
# -*- coding: utf-8 -*-
 

	
 
from sqlalchemy.util import OrderedDict
 

	
 

	
 
__all__ = ['databases', 'operations']
 

	
 
databases = ('sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird')
 

	
 
# Map operation names to function names
 
operations = OrderedDict()
 
operations['upgrade'] = 'upgrade'
 
operations['downgrade'] = 'downgrade'
rhodecode/lib/dbmigrate/migrate/versioning/genmodel.py
Show inline comments
 
new file 100644
 
"""
 
   Code to generate a Python model from a database or differences
 
   between a model and database.
 

	
 
   Some of this is borrowed heavily from the AutoCode project at:
 
   http://code.google.com/p/sqlautocode/
 
"""
 

	
 
import sys
 
import logging
 

	
 
import sqlalchemy
 

	
 
import migrate
 
import migrate.changeset
 

	
 

	
 
log = logging.getLogger(__name__)
 
HEADER = """
 
## File autogenerated by genmodel.py
 

	
 
from sqlalchemy import *
 
meta = MetaData()
 
"""
 

	
 
DECLARATIVE_HEADER = """
 
## File autogenerated by genmodel.py
 

	
 
from sqlalchemy import *
 
from sqlalchemy.ext import declarative
 

	
 
Base = declarative.declarative_base()
 
"""
 

	
 

	
 
class ModelGenerator(object):
 

	
 
    def __init__(self, diff, engine, declarative=False):
 
        self.diff = diff
 
        self.engine = engine
 
        self.declarative = declarative
 

	
 
    def column_repr(self, col):
 
        kwarg = []
 
        if col.key != col.name:
 
            kwarg.append('key')
 
        if col.primary_key:
 
            col.primary_key = True  # otherwise it dumps it as 1
 
            kwarg.append('primary_key')
 
        if not col.nullable:
 
            kwarg.append('nullable')
 
        if col.onupdate:
 
            kwarg.append('onupdate')
 
        if col.default:
 
            if col.primary_key:
 
                # I found that PostgreSQL automatically creates a
 
                # default value for the sequence, but let's not show
 
                # that.
 
                pass
 
            else:
 
                kwarg.append('default')
 
        ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
 

	
 
        # crs: not sure if this is good idea, but it gets rid of extra
 
        # u''
 
        name = col.name.encode('utf8')
 

	
 
        type_ = col.type
 
        for cls in col.type.__class__.__mro__:
 
            if cls.__module__ == 'sqlalchemy.types' and \
 
                not cls.__name__.isupper():
 
                if cls is not type_.__class__:
 
                    type_ = cls()
 
                break
 

	
 
        data = {
 
            'name': name,
 
            'type': type_,
 
            'constraints': ', '.join([repr(cn) for cn in col.constraints]),
 
            'args': ks and ks or ''}
 

	
 
        if data['constraints']:
 
            if data['args']:
 
                data['args'] = ',' + data['args']
 

	
 
        if data['constraints'] or data['args']:
 
            data['maybeComma'] = ','
 
        else:
 
            data['maybeComma'] = ''
 

	
 
        commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
 
        commonStuff = commonStuff.strip()
 
        data['commonStuff'] = commonStuff
 
        if self.declarative:
 
            return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
 
        else:
 
            return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
 

	
 
    def getTableDefn(self, table):
 
        out = []
 
        tableName = table.name
 
        if self.declarative:
 
            out.append("class %(table)s(Base):" % {'table': tableName})
 
            out.append("  __tablename__ = '%(table)s'" % {'table': tableName})
 
            for col in table.columns:
 
                out.append("  %s" % self.column_repr(col))
 
        else:
 
            out.append("%(table)s = Table('%(table)s', meta," % \
 
                           {'table': tableName})
 
            for col in table.columns:
 
                out.append("  %s," % self.column_repr(col))
 
            out.append(")")
 
        return out
 

	
 
    def _get_tables(self, missingA=False, missingB=False, modified=False):
 
        to_process = []
 
        for bool_,names,metadata in (
 
            (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
 
            (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
 
            (modified,self.diff.tables_different,self.diff.metadataA),
 
                ):
 
            if bool_:
 
                for name in names:
 
                    yield metadata.tables.get(name)
 
        
 
    def toPython(self):
 
        """Assume database is current and model is empty."""
 
        out = []
 
        if self.declarative:
 
            out.append(DECLARATIVE_HEADER)
 
        else:
 
            out.append(HEADER)
 
        out.append("")
 
        for table in self._get_tables(missingA=True):
 
            out.extend(self.getTableDefn(table))
 
            out.append("")
 
        return '\n'.join(out)
 

	
 
    def toUpgradeDowngradePython(self, indent='    '):
 
        ''' Assume model is most current and database is out-of-date. '''
 
        decls = ['from migrate.changeset import schema',
 
                 'meta = MetaData()']
 
        for table in self._get_tables(
 
            missingA=True,missingB=True,modified=True
 
            ):
 
            decls.extend(self.getTableDefn(table))
 

	
 
        upgradeCommands, downgradeCommands = [], []
 
        for tableName in self.diff.tables_missing_from_A:
 
            upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
 
            downgradeCommands.append("%(table)s.create()" % \
 
                                         {'table': tableName})
 
        for tableName in self.diff.tables_missing_from_B:
 
            upgradeCommands.append("%(table)s.create()" % {'table': tableName})
 
            downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
 

	
 
        for tableName in self.diff.tables_different:
 
            dbTable = self.diff.metadataB.tables[tableName]
 
            missingInDatabase, missingInModel, diffDecl = \
 
                self.diff.colDiffs[tableName]
 
            for col in missingInDatabase:
 
                upgradeCommands.append('%s.columns[%r].create()' % (
 
                        tableName, col.name))
 
                downgradeCommands.append('%s.columns[%r].drop()' % (
 
                        tableName, col.name))
 
            for col in missingInModel:
 
                upgradeCommands.append('%s.columns[%r].drop()' % (
 
                        tableName, col.name))
 
                downgradeCommands.append('%s.columns[%r].create()' % (
 
                        tableName, col.name))
 
            for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
 
                upgradeCommands.append(
 
                    'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
 
                        tableName, modelCol.name, databaseCol.name))
 
                downgradeCommands.append(
 
                    'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
 
                        tableName, modelCol.name, databaseCol.name))
 
        pre_command = '    meta.bind = migrate_engine'
 

	
 
        return (
 
            '\n'.join(decls),
 
            '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
 
            '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
 

	
 
    def _db_can_handle_this_change(self, td):
 
        if (td.columns_missing_from_B
 
            and not td.columns_missing_from_A
 
            and not td.columns_different):
 
            # Even sqlite can handle this.
 
            return True
 
        else:
 
            return not self.engine.url.drivername.startswith('sqlite')
 

	
 
    def applyModel(self):
 
        """Apply model to current database."""
 

	
 
        meta = sqlalchemy.MetaData(self.engine)
 

	
 
        for table in self._get_tables(missingA=True):
 
            table = table.tometadata(meta)
 
            table.drop()
 
        for table in self._get_tables(missingB=True):
 
            table = table.tometadata(meta)
 
            table.create()
 
        for modelTable in self._get_tables(modified=True):
 
            tableName = modelTable.name
 
            modelTable = modelTable.tometadata(meta)
 
            dbTable = self.diff.metadataB.tables[tableName]
 

	
 
            td = self.diff.tables_different[tableName]
 
            
 
            if self._db_can_handle_this_change(td):
 
                
 
                for col in td.columns_missing_from_B:
 
                    modelTable.columns[col].create()
 
                for col in td.columns_missing_from_A:
 
                    dbTable.columns[col].drop()
 
                # XXX handle column changes here.
 
            else:
 
                # Sqlite doesn't support drop column, so you have to
 
                # do more: create temp table, copy data to it, drop
 
                # old table, create new table, copy data back.
 
                #
 
                # I wonder if this is guaranteed to be unique?
 
                tempName = '_temp_%s' % modelTable.name
 

	
 
                def getCopyStatement():
 
                    preparer = self.engine.dialect.preparer
 
                    commonCols = []
 
                    for modelCol in modelTable.columns:
 
                        if modelCol.name in dbTable.columns:
 
                            commonCols.append(modelCol.name)
 
                    commonColsStr = ', '.join(commonCols)
 
                    return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
 
                        (tableName, commonColsStr, commonColsStr, tempName)
 

	
 
                # Move the data in one transaction, so that we don't
 
                # leave the database in a nasty state.
 
                connection = self.engine.connect()
 
                trans = connection.begin()
 
                try:
 
                    connection.execute(
 
                        'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
 
                            (tempName, modelTable.name))
 
                    # make sure the drop takes place inside our
 
                    # transaction with the bind parameter
 
                    modelTable.drop(bind=connection)
 
                    modelTable.create(bind=connection)
 
                    connection.execute(getCopyStatement())
 
                    connection.execute('DROP TABLE %s' % tempName)
 
                    trans.commit()
 
                except:
 
                    trans.rollback()
 
                    raise
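
ModelGenerator is normally fed a diff produced by the sibling schemadiff module, which
is not shown in this hunk; a sketch under that assumption (editor's annotation, not
part of the vendored file), with a placeholder engine URL and MetaData:

    from sqlalchemy import create_engine, MetaData
    from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff

    engine = create_engine('sqlite:///app.db')
    model_meta = MetaData()                 # the application's declared schema

    diff = schemadiff.getDiffOfModelAgainstDatabase(model_meta, engine)
    gen = genmodel.ModelGenerator(diff, engine)
    print gen.toPython()                    # dump the live database as model source
    decls, upgrade_src, downgrade_src = gen.toUpgradeDowngradePython()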
rhodecode/lib/dbmigrate/migrate/versioning/migrate_repository.py
Show inline comments
 
new file 100644
 
"""
 
   Script to migrate repository from sqlalchemy <= 0.4.4 to the new
 
   repository schema. This shouldn't use any other migrate modules, so
 
   that it can work in any version.
 
"""
 

	
 
import os
 
import sys
 
import logging
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
def usage():
 
    """Gives usage information."""
 
    print """Usage: %(prog)s repository-to-migrate
 

	
 
    Upgrade your repository to the new flat format.
 

	
 
    NOTE: You should probably make a backup before running this.
 
    """ % {'prog': sys.argv[0]}
 

	
 
    sys.exit(1)
 

	
 

	
 
def delete_file(filepath):
 
    """Deletes a file and prints a message."""
 
    log.info('Deleting file: %s' % filepath)
 
    os.remove(filepath)
 

	
 

	
 
def move_file(src, tgt):
 
    """Moves a file and prints a message."""
 
    log.info('Moving file %s to %s' % (src, tgt))
 
    if os.path.exists(tgt):
 
        raise Exception(
 
            'Cannot move file %s because target %s already exists' % \
 
                (src, tgt))
 
    os.rename(src, tgt)
 

	
 

	
 
def delete_directory(dirpath):
 
    """Delete a directory and print a message."""
 
    log.info('Deleting directory: %s' % dirpath)
 
    os.rmdir(dirpath)
 

	
 

	
 
def migrate_repository(repos):
 
    """Does the actual migration to the new repository format."""
 
    log.info('Migrating repository at: %s to new format' % repos)
 
    versions = '%s/versions' % repos
 
    dirs = os.listdir(versions)
 
    # Keep only numeric directory names.
 
    numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()]
 
    numdirs.sort()  # Sort list.
 
    for dirname in numdirs:
 
        origdir = '%s/%s' % (versions, dirname)
 
        log.info('Working on directory: %s' % origdir)
 
        files = os.listdir(origdir)
 
        files.sort()
 
        for filename in files:
 
            # Delete compiled Python files.
 
            if filename.endswith('.pyc') or filename.endswith('.pyo'):
 
                delete_file('%s/%s' % (origdir, filename))
 

	
 
            # Delete empty __init__.py files.
 
            origfile = '%s/__init__.py' % origdir
 
            if os.path.exists(origfile) and len(open(origfile).read()) == 0:
 
                delete_file(origfile)
 

	
 
            # Move sql upgrade scripts.
 
            if filename.endswith('.sql'):
 
                version, dbms, operation = filename.split('.', 3)[0:3]
 
                origfile = '%s/%s' % (origdir, filename)
 
                # For instance:  2.postgres.upgrade.sql ->
 
                #  002_postgres_upgrade.sql
 
                tgtfile = '%s/%03d_%s_%s.sql' % (
 
                    versions, int(version), dbms, operation)
 
                move_file(origfile, tgtfile)
 

	
 
        # Move Python upgrade script.
 
        pyfile = '%s.py' % dirname
 
        pyfilepath = '%s/%s' % (origdir, pyfile)
 
        if os.path.exists(pyfilepath):
 
            tgtfile = '%s/%03d.py' % (versions, int(dirname))
 
            move_file(pyfilepath, tgtfile)
 

	
 
        # Try to remove directory. Will fail if it's not empty.
 
        delete_directory(origdir)
 

	
 

	
 
def main():
 
    """Main function to be called when using this script."""
 
    if len(sys.argv) != 2:
 
        usage()
 
    migrate_repository(sys.argv[1])
 

	
 

	
 
if __name__ == '__main__':
 
    main()
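
The script is meant to be run from a shell, e.g. ``python migrate_repository.py
/path/to/old/repository``, but the worker function can also be called directly
(editor's annotation, not part of the vendored file; the path is a placeholder):

    from rhodecode.lib.dbmigrate.migrate.versioning import migrate_repository

    migrate_repository.migrate_repository('/path/to/old/repository')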
rhodecode/lib/dbmigrate/migrate/versioning/pathed.py
Show inline comments
 
new file 100644
 
"""
 
   A path/directory class.
 
"""
 

	
 
import os
 
import shutil
 
import logging
 

	
 
from migrate import exceptions
 
from migrate.versioning.config import *
 
from migrate.versioning.util import KeyedInstance
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class Pathed(KeyedInstance):
 
    """
 
    A class associated with a path/directory tree.
 

	
 
    Only one instance of this class may exist for a particular file;
 
    __new__ will return an existing instance if possible
 
    """
 
    parent = None
 

	
 
    @classmethod
 
    def _key(cls, path):
 
        return str(path)
 

	
 
    def __init__(self, path):
 
        self.path = path
 
        if self.__class__.parent is not None:
 
            self._init_parent(path)
 

	
 
    def _init_parent(self, path):
 
        """Try to initialize this object's parent, if it has one"""
 
        parent_path = self.__class__._parent_path(path)
 
        self.parent = self.__class__.parent(parent_path)
 
        log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path))
 
        self.parent._init_child(path, self)
 

	
 
    def _init_child(self, child, path):
 
        """Run when a child of this object is initialized.
 

	
 
        Parameters: the child object; the path to this object (its
 
        parent)
 
        """
 

	
 
    @classmethod
 
    def _parent_path(cls, path):
 
        """
 
        Fetch the path of this object's parent from this object's path.
 
        """
 
        # os.path.dirname(), but strip directories like files (like
 
        # unix basename)
 
        #
 
        # Treat directories like files...
 
        if path[-1] == '/':
 
            path = path[:-1]
 
        ret = os.path.dirname(path)
 
        return ret
 

	
 
    @classmethod
 
    def require_notfound(cls, path):
 
        """Ensures a given path does not already exist"""
 
        if os.path.exists(path):
 
            raise exceptions.PathFoundError(path)
 

	
 
    @classmethod
 
    def require_found(cls, path):
 
        """Ensures a given path already exists"""
 
        if not os.path.exists(path):
 
            raise exceptions.PathNotFoundError(path)
 

	
 
    def __str__(self):
 
        return self.path
rhodecode/lib/dbmigrate/migrate/versioning/repository.py
Show inline comments
 
new file 100644
 
"""
 
   SQLAlchemy migrate repository management.
 
"""
 
import os
 
import shutil
 
import string
 
import logging
 

	
 
from pkg_resources import resource_filename
 
from tempita import Template as TempitaTemplate
 

	
 
from migrate import exceptions
 
from migrate.versioning import version, pathed, cfgparse
 
from migrate.versioning.template import Template
 
from migrate.versioning.config import *
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class Changeset(dict):
 
    """A collection of changes to be applied to a database.
 

	
 
    Changesets are bound to a repository and manage a set of
 
    scripts from that repository.
 

	
 
    Behaves like a dict, for the most part. Keys are ordered based on step value.
 
    """
 

	
 
    def __init__(self, start, *changes, **k):
 
        """
 
        Give a start version; step must be explicitly stated.
 
        """
 
        self.step = k.pop('step', 1)
 
        self.start = version.VerNum(start)
 
        self.end = self.start
 
        for change in changes:
 
            self.add(change)
 

	
 
    def __iter__(self):
 
        return iter(self.items())
 

	
 
    def keys(self):
 
        """
 
        In a series of upgrades x -> y, keys are version x. Sorted.
 
        """
 
        ret = super(Changeset, self).keys()
 
        # Reverse order if downgrading
 
        ret.sort(reverse=(self.step < 1))
 
        return ret
 

	
 
    def values(self):
 
        return [self[k] for k in self.keys()]
 

	
 
    def items(self):
 
        return zip(self.keys(), self.values())
 

	
 
    def add(self, change):
 
        """Add new change to changeset"""
 
        key = self.end
 
        self.end += self.step
 
        self[key] = change
 

	
 
    def run(self, *p, **k):
 
        """Run the changeset scripts"""
 
        for version, script in self:
 
            script.run(*p, **k)
 

	
 

	
 
class Repository(pathed.Pathed):
 
    """A project's change script repository"""
 

	
 
    _config = 'migrate.cfg'
 
    _versions = 'versions'
 

	
 
    def __init__(self, path):
 
        log.debug('Loading repository %s...' % path)
 
        self.verify(path)
 
        super(Repository, self).__init__(path)
 
        self.config = cfgparse.Config(os.path.join(self.path, self._config))
 
        self.versions = version.Collection(os.path.join(self.path,
 
                                                      self._versions))
 
        log.debug('Repository %s loaded successfully' % path)
 
        log.debug('Config: %r' % self.config.to_dict())
 

	
 
    @classmethod
 
    def verify(cls, path):
 
        """
 
        Ensure the target path is a valid repository.
 

	
 
        :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
 
        """
 
        # Ensure the existence of required files
 
        try:
 
            cls.require_found(path)
 
            cls.require_found(os.path.join(path, cls._config))
 
            cls.require_found(os.path.join(path, cls._versions))
 
        except exceptions.PathNotFoundError, e:
 
            raise exceptions.InvalidRepositoryError(path)
 

	
 
    @classmethod
 
    def prepare_config(cls, tmpl_dir, name, options=None):
 
        """
 
        Prepare a project configuration file for a new project.
 

	
 
        :param tmpl_dir: Path to Repository template
 
        :param options: Optional dict of configuration defaults
 
        :param name: Repository name
 
        :type tmpl_dir: string
 
        :type options: dict
 
        :type name: string
 
        :returns: Populated config file
 
        """
 
        if options is None:
 
            options = {}
 
        options.setdefault('version_table', 'migrate_version')
 
        options.setdefault('repository_id', name)
 
        options.setdefault('required_dbs', [])
 

	
 
        tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
 
        ret = TempitaTemplate(tmpl).substitute(options)
 

	
 
        # cleanup
 
        del options['__template_name__']
 

	
 
        return ret
 

	
 
    @classmethod
 
    def create(cls, path, name, **opts):
 
        """Create a repository at a specified path"""
 
        cls.require_notfound(path)
 
        theme = opts.pop('templates_theme', None)
 
        t_path = opts.pop('templates_path', None)
 

	
 
        # Create repository
 
        tmpl_dir = Template(t_path).get_repository(theme=theme)
 
        shutil.copytree(tmpl_dir, path)
 

	
 
        # Edit config defaults
 
        config_text = cls.prepare_config(tmpl_dir, name, options=opts)
 
        fd = open(os.path.join(path, cls._config), 'w')
 
        fd.write(config_text)
 
        fd.close()
 

	
 
        opts['repository_name'] = name
 

	
 
        # Create a management script
 
        manager = os.path.join(path, 'manage.py')
 
        Repository.create_manage_file(manager, templates_theme=theme,
 
            templates_path=t_path, **opts)
 

	
 
        return cls(path)
 

	
 
    def create_script(self, description, **k):
 
        """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
 
        self.versions.create_new_python_version(description, **k)
 

	
 
    def create_script_sql(self, database, **k):
 
        """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
 
        self.versions.create_new_sql_version(database, **k)
 

	
 
    @property
 
    def latest(self):
 
        """API to :attr:`migrate.versioning.version.Collection.latest`"""
 
        return self.versions.latest
 

	
 
    @property
 
    def version_table(self):
 
        """Returns version_table name specified in config"""
 
        return self.config.get('db_settings', 'version_table')
 

	
 
    @property
 
    def id(self):
 
        """Returns repository id specified in config"""
 
        return self.config.get('db_settings', 'repository_id')
 

	
 
    def version(self, *p, **k):
 
        """API to :attr:`migrate.versioning.version.Collection.version`"""
 
        return self.versions.version(*p, **k)
 

	
 
    @classmethod
 
    def clear(cls):
 
        # TODO: deletes repo
 
        super(Repository, cls).clear()
 
        version.Collection.clear()
 

	
 
    def changeset(self, database, start, end=None):
 
        """Create a changeset to migrate this database from ver. start to end/latest.
 

	
 
        :param database: name of database to generate changeset
 
        :param start: version to start at
 
        :param end: version to end at (latest if None given)
 
        :type database: string
 
        :type start: int
 
        :type end: int
 
        :returns: :class:`Changeset instance <migrate.versioning.repository.Changeset>`
 
        """
 
        start = version.VerNum(start)
 

	
 
        if end is None:
 
            end = self.latest
 
        else:
 
            end = version.VerNum(end)
 

	
 
        if start <= end:
 
            step = 1
 
            range_mod = 1
 
            op = 'upgrade'
 
        else:
 
            step = -1
 
            range_mod = 0
 
            op = 'downgrade'
 

	
 
        versions = range(start + range_mod, end + range_mod, step)
 
        changes = [self.version(v).script(database, op) for v in versions]
 
        ret = Changeset(start, step=step, *changes)
 
        return ret
 

	
 
    @classmethod
 
    def create_manage_file(cls, file_, **opts):
 
        """Create a project management script (manage.py)
 
        
 
        :param file_: Destination file to be written
 
        :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
 
        """
 
        mng_file = Template(opts.pop('templates_path', None))\
 
            .get_manage(theme=opts.pop('templates_theme', None))
 

	
 
        tmpl = open(mng_file).read()
 
        fd = open(file_, 'w')
 
        fd.write(TempitaTemplate(tmpl).substitute(opts))
 
        fd.close()
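For reference, a minimal sketch of how the version-range arithmetic in Repository.changeset plays out; the repository path, database name and version numbers below are assumed for illustration only:

    from migrate.versioning.repository import Repository

    repo = Repository('./repo')            # assumed: an existing repository directory with scripts 1..5
    # upgrading 3 -> 5: step=1, range_mod=1, so scripts 4 and 5 are collected as 'upgrade'
    cs_up = repo.changeset('sqlite', 3, 5)
    # downgrading 5 -> 3: step=-1, range_mod=0, so scripts 5 and 4 are collected as 'downgrade'
    cs_down = repo.changeset('sqlite', 5, 3)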
rhodecode/lib/dbmigrate/migrate/versioning/schema.py
Show inline comments
 
new file 100644
 
"""
 
   Database schema version management.
 
"""
 
import sys
 
import logging
 

	
 
from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
 
    create_engine)
 
from sqlalchemy.sql import and_
 
from sqlalchemy import exceptions as sa_exceptions
 
from sqlalchemy.sql import bindparam
 

	
 
from migrate import exceptions
 
from migrate.versioning import genmodel, schemadiff
 
from migrate.versioning.repository import Repository
 
from migrate.versioning.util import load_model
 
from migrate.versioning.version import VerNum
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class ControlledSchema(object):
 
    """A database under version control"""
 

	
 
    def __init__(self, engine, repository):
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 
        self.engine = engine
 
        self.repository = repository
 
        self.meta = MetaData(engine)
 
        self.load()
 

	
 
    def __eq__(self, other):
 
        """Compare two schemas by repositories and versions"""
 
        return (self.repository is other.repository \
 
            and self.version == other.version)
 

	
 
    def load(self):
 
        """Load controlled schema version info from DB"""
 
        tname = self.repository.version_table
 
        try:
 
            if not hasattr(self, 'table') or self.table is None:
 
                self.table = Table(tname, self.meta, autoload=True)
 

	
 
            result = self.engine.execute(self.table.select(
 
                self.table.c.repository_id == str(self.repository.id)))
 

	
 
            data = list(result)[0]
 
        except:
 
            cls, exc, tb = sys.exc_info()
 
            raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
 

	
 
        self.version = data['version']
 
        return data
 

	
 
    def drop(self):
 
        """
 
        Remove version control from a database.
 
        """
 
        try:
 
            self.table.drop()
 
        except (sa_exceptions.SQLError):
 
            raise exceptions.DatabaseNotControlledError(str(self.table))
 

	
 
    def changeset(self, version=None):
 
        """API to Changeset creation.
 
        
 
        Uses self.version for start version and engine.name
 
        to get database name.
 
        """
 
        database = self.engine.name
 
        start_ver = self.version
 
        changeset = self.repository.changeset(database, start_ver, version)
 
        return changeset
 

	
 
    def runchange(self, ver, change, step):
 
        startver = ver
 
        endver = ver + step
 
        # Current database version must be correct! Don't run if corrupt!
 
        if self.version != startver:
 
            raise exceptions.InvalidVersionError("%s is not %s" % \
 
                                                     (self.version, startver))
 
        # Run the change
 
        change.run(self.engine, step)
 

	
 
        # Update/refresh database version
 
        self.update_repository_table(startver, endver)
 
        self.load()
 

	
 
    def update_repository_table(self, startver, endver):
 
        """Update version_table with new information"""
 
        update = self.table.update(and_(self.table.c.version == int(startver),
 
             self.table.c.repository_id == str(self.repository.id)))
 
        self.engine.execute(update, version=int(endver))
 

	
 
    def upgrade(self, version=None):
 
        """
 
        Upgrade (or downgrade) to a specified version, or latest version.
 
        """
 
        changeset = self.changeset(version)
 
        for ver, change in changeset:
 
            self.runchange(ver, change, changeset.step)
 

	
 
    def update_db_from_model(self, model):
 
        """
 
        Modify the database to match the structure of the current Python model.
 
        """
 
        model = load_model(model)
 

	
 
        diff = schemadiff.getDiffOfModelAgainstDatabase(
 
            model, self.engine, excludeTables=[self.repository.version_table]
 
            )
 
        genmodel.ModelGenerator(diff,self.engine).applyModel()
 

	
 
        self.update_repository_table(self.version, int(self.repository.latest))
 

	
 
        self.load()
 

	
 
    @classmethod
 
    def create(cls, engine, repository, version=None):
 
        """
 
        Declare a database to be under a repository's version control.
 

	
 
        :raises: :exc:`DatabaseAlreadyControlledError`
 
        :returns: :class:`ControlledSchema`
 
        """
 
        # Confirm that the version # is valid: positive, integer,
 
        # exists in repos
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 
        version = cls._validate_version(repository, version)
 
        table = cls._create_table_version(engine, repository, version)
 
        # TODO: history table
 
        # Load repository information and return
 
        return cls(engine, repository)
 

	
 
    @classmethod
 
    def _validate_version(cls, repository, version):
 
        """
 
        Ensures this is a valid version number for this repository.
 

	
 
        :raises: :exc:`InvalidVersionError` if invalid
 
        :return: valid version number
 
        """
 
        if version is None:
 
            version = 0
 
        try:
 
            version = VerNum(version) # raises valueerror
 
            if version < 0 or version > repository.latest:
 
                raise ValueError()
 
        except ValueError:
 
            raise exceptions.InvalidVersionError(version)
 
        return version
 

	
 
    @classmethod
 
    def _create_table_version(cls, engine, repository, version):
 
        """
 
        Creates the versioning table in a database.
 

	
 
        :raises: :exc:`DatabaseAlreadyControlledError`
 
        """
 
        # Create tables
 
        tname = repository.version_table
 
        meta = MetaData(engine)
 

	
 
        table = Table(
 
            tname, meta,
 
            Column('repository_id', String(250), primary_key=True),
 
            Column('repository_path', Text),
 
            Column('version', Integer), )
 

	
 
        # there can be multiple repositories/schemas in the same db
 
        if not table.exists():
 
            table.create()
 

	
 
        # test for existing repository_id
 
        s = table.select(table.c.repository_id == bindparam("repository_id"))
 
        result = engine.execute(s, repository_id=repository.id)
 
        if result.fetchone():
 
            raise exceptions.DatabaseAlreadyControlledError
 

	
 
        # Insert data
 
        engine.execute(table.insert().values(
 
                           repository_id=repository.id,
 
                           repository_path=repository.path,
 
                           version=int(version)))
 
        return table
 

	
 
    @classmethod
 
    def compare_model_to_db(cls, engine, model, repository):
 
        """
 
        Compare the current model against the current database.
 
        """
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 
        model = load_model(model)
 

	
 
        diff = schemadiff.getDiffOfModelAgainstDatabase(
 
            model, engine, excludeTables=[repository.version_table])
 
        return diff
 

	
 
    @classmethod
 
    def create_model(cls, engine, repository, declarative=False):
 
        """
 
        Dump the current database as a Python model.
 
        """
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 

	
 
        diff = schemadiff.getDiffOfModelAgainstDatabase(
 
            MetaData(), engine, excludeTables=[repository.version_table]
 
            )
 
        return genmodel.ModelGenerator(diff, engine, declarative).toPython()
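A minimal usage sketch for ControlledSchema, assuming a SQLite database and an existing repository directory at ./repo:

    from sqlalchemy import create_engine
    from migrate.versioning.schema import ControlledSchema

    engine = create_engine('sqlite:///app.db')            # assumed URL
    schema = ControlledSchema.create(engine, './repo')    # put the db under version control at version 0
    schema.upgrade()                                       # run every change script up to repository.latest
    print schema.version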
rhodecode/lib/dbmigrate/migrate/versioning/schemadiff.py
Show inline comments
 
new file 100644
 
"""
 
   Schema differencing support.
 
"""
 

	
 
import logging
 
import sqlalchemy
 

	
 
from migrate.changeset import SQLA_06
 
from sqlalchemy.types import Float
 

	
 
log = logging.getLogger(__name__)
 

	
 
def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
 
    """
 
    Return differences of model against database.
 

	
 
    :return: object which will evaluate to :keyword:`True` if there \
 
      are differences else :keyword:`False`.
 
    """
 
    return SchemaDiff(metadata,
 
                      sqlalchemy.MetaData(engine, reflect=True),
 
                      labelA='model',
 
                      labelB='database',
 
                      excludeTables=excludeTables)
 

	
 

	
 
def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
 
    """
 
    Return differences of model against another model.
 

	
 
    :return: object which will evaluate to :keyword:`True` if there \
 
      are differences else :keyword:`False`.
 
    """
 
    return SchemaDiff(metadataA, metadataB, excludeTables=excludeTables)
 

	
 

	
 
class ColDiff(object):
 
    """
 
    Container for differences in one :class:`~sqlalchemy.schema.Column`
 
    between two :class:`~sqlalchemy.schema.Table` instances, ``A``
 
    and ``B``.
 
    
 
    .. attribute:: col_A
 

	
 
      The :class:`~sqlalchemy.schema.Column` object for A.
 
      
 
    .. attribute:: col_B
 

	
 
      The :class:`~sqlalchemy.schema.Column` object for B.
 

	
 
    .. attribute:: type_A
 

	
 
      The most generic type of the :class:`~sqlalchemy.schema.Column`
 
      object in A. 
 
      
 
    .. attribute:: type_B
 

	
 
      The most generic type of the :class:`~sqlalchemy.schema.Column`
 
      object in B.
 
      
 
    """
 
    
 
    diff = False
 

	
 
    def __init__(self,col_A,col_B):
 
        self.col_A = col_A
 
        self.col_B = col_B
 

	
 
        self.type_A = col_A.type
 
        self.type_B = col_B.type
 

	
 
        self.affinity_A = self.type_A._type_affinity
 
        self.affinity_B = self.type_B._type_affinity
 

	
 
        if self.affinity_A is not self.affinity_B:
 
            self.diff = True
 
            return
 

	
 
        if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
 
            if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
 
                self.diff=True
 
                return
 

	
 
        for attr in ('precision','scale','length'):
 
            A = getattr(self.type_A,attr,None)
 
            B = getattr(self.type_B,attr,None)
 
            if not (A is None or B is None) and A!=B:
 
                self.diff=True
 
                return
 
        
 
    def __nonzero__(self):
 
        return self.diff
 
    
 
class TableDiff(object):
 
    """
 
    Container for differences in one :class:`~sqlalchemy.schema.Table`
 
    between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
 
    and ``B``.
 

	
 
    .. attribute:: columns_missing_from_A
 

	
 
      A sequence of column names that were found in B but weren't in
 
      A.
 
      
 
    .. attribute:: columns_missing_from_B
 

	
 
      A sequence of column names that were found in A but weren't in
 
      B.
 
      
 
    .. attribute:: columns_different
 

	
 
      A dictionary containing information about columns that were
 
      found to be different.
 
      It maps column names to a :class:`ColDiff` objects describing the
 
      differences found.
 
    """
 
    __slots__ = (
 
        'columns_missing_from_A',
 
        'columns_missing_from_B',
 
        'columns_different',
 
        )
 

	
 
    def __nonzero__(self):
 
        return bool(
 
            self.columns_missing_from_A or
 
            self.columns_missing_from_B or
 
            self.columns_different
 
            )
 
    
 
class SchemaDiff(object):
 
    """
 
    Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
 
    objects.
 

	
 
    The string representation of a :class:`SchemaDiff` will summarise
 
    the changes found between the two
 
    :class:`~sqlalchemy.schema.MetaData` objects.
 

	
 
    The length of a :class:`SchemaDiff` will give the number of
 
    changes found, enabling it to be used much like a boolean in
 
    expressions.
 
        
 
    :param metadataA:
 
      First :class:`~sqlalchemy.schema.MetaData` to compare.
 
      
 
    :param metadataB:
 
      Second :class:`~sqlalchemy.schema.MetaData` to compare.
 
      
 
    :param labelA:
 
      The label to use in messages about the first
 
      :class:`~sqlalchemy.schema.MetaData`. 
 
    
 
    :param labelB: 
 
      The label to use in messages about the second
 
      :class:`~sqlalchemy.schema.MetaData`. 
 
    
 
    :param excludeTables:
 
      A sequence of table names to exclude.
 
      
 
    .. attribute:: tables_missing_from_A
 

	
 
      A sequence of table names that were found in B but weren't in
 
      A.
 
      
 
    .. attribute:: tables_missing_from_B
 

	
 
      A sequence of table names that were found in A but weren't in
 
      B.
 
      
 
    .. attribute:: tables_different
 

	
 
      A dictionary containing information about tables that were found
 
      to be different.
 
      It maps table names to a :class:`TableDiff` objects describing the
 
      differences found.
 
    """
 

	
 
    def __init__(self,
 
                 metadataA, metadataB,
 
                 labelA='metadataA',
 
                 labelB='metadataB',
 
                 excludeTables=None):
 

	
 
        self.metadataA, self.metadataB = metadataA, metadataB
 
        self.labelA, self.labelB = labelA, labelB
 
        self.label_width = max(len(labelA),len(labelB))
 
        excludeTables = set(excludeTables or [])
 

	
 
        A_table_names = set(metadataA.tables.keys())
 
        B_table_names = set(metadataB.tables.keys())
 

	
 
        self.tables_missing_from_A = sorted(
 
            B_table_names - A_table_names - excludeTables
 
            )
 
        self.tables_missing_from_B = sorted(
 
            A_table_names - B_table_names - excludeTables
 
            )
 
        
 
        self.tables_different = {}
 
        for table_name in A_table_names.intersection(B_table_names):
 

	
 
            td = TableDiff()
 
            
 
            A_table = metadataA.tables[table_name]
 
            B_table = metadataB.tables[table_name]
 
            
 
            A_column_names = set(A_table.columns.keys())
 
            B_column_names = set(B_table.columns.keys())
 

	
 
            td.columns_missing_from_A = sorted(
 
                B_column_names - A_column_names
 
                )
 
            
 
            td.columns_missing_from_B = sorted(
 
                A_column_names - B_column_names
 
                )
 
            
 
            td.columns_different = {}
 

	
 
            for col_name in A_column_names.intersection(B_column_names):
 

	
 
                cd = ColDiff(
 
                    A_table.columns.get(col_name),
 
                    B_table.columns.get(col_name)
 
                    )
 

	
 
                if cd:
 
                    td.columns_different[col_name]=cd
 
                
 
            # XXX - index and constraint differences should
 
            #       be checked for here
 

	
 
            if td:
 
                self.tables_different[table_name]=td
 

	
 
    def __str__(self):
 
        ''' Summarize differences. '''
 
        out = []
 
        column_template ='      %%%is: %%r' % self.label_width
 
        
 
        for names,label in (
 
            (self.tables_missing_from_A,self.labelA),
 
            (self.tables_missing_from_B,self.labelB),
 
            ):
 
            if names:
 
                out.append(
 
                    '  tables missing from %s: %s' % (
 
                        label,', '.join(sorted(names))
 
                        )
 
                    )
 
                
 
        for name,td in sorted(self.tables_different.items()):
 
            out.append(
 
               '  table with differences: %s' % name
 
               )
 
            for names,label in (
 
                (td.columns_missing_from_A,self.labelA),
 
                (td.columns_missing_from_B,self.labelB),
 
                ):
 
                if names:
 
                    out.append(
 
                        '    %s missing these columns: %s' % (
 
                            label,', '.join(sorted(names))
 
                            )
 
                        )
 
            for name,cd in td.columns_different.items():
 
                out.append('    column with differences: %s' % name)
 
                out.append(column_template % (self.labelA,cd.col_A))
 
                out.append(column_template % (self.labelB,cd.col_B))
 
                
 
        if out:
 
            out.insert(0, 'Schema diffs:')
 
            return '\n'.join(out)
 
        else:
 
            return 'No schema diffs'
 

	
 
    def __len__(self):
 
        """
 
        Used in bool evaluation, return of 0 means no diffs.
 
        """
 
        return (
 
            len(self.tables_missing_from_A) +
 
            len(self.tables_missing_from_B) +
 
            len(self.tables_different)
 
            )
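A short sketch of how the diff helpers above are typically driven; the model metadata and engine are assumed to exist already:

    from migrate.versioning.schemadiff import getDiffOfModelAgainstDatabase

    diff = getDiffOfModelAgainstDatabase(model_metadata, engine,
                                         excludeTables=['migrate_version'])
    if diff:                # SchemaDiff.__len__ makes this a truth test on the number of changes
        print diff          # human-readable summary produced by SchemaDiff.__str__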
rhodecode/lib/dbmigrate/migrate/versioning/script/__init__.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
from migrate.versioning.script.base import BaseScript
 
from migrate.versioning.script.py import PythonScript
 
from migrate.versioning.script.sql import SqlScript
rhodecode/lib/dbmigrate/migrate/versioning/script/base.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 
import logging
 

	
 
from migrate import exceptions
 
from migrate.versioning.config import operations
 
from migrate.versioning import pathed
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class BaseScript(pathed.Pathed):
 
    """Base class for other types of scripts.
 
    All scripts have the following properties:
 

	
 
    source (script.source())
 
      The source code of the script
 
    version (script.version())
 
      The version number of the script
 
    operations (script.operations())
 
      The operations defined by the script: upgrade(), downgrade() or both.
 
      Returns a tuple of operations.
 
      A specific operation can also be checked for, e.g. script.operation(Script.ops.up)
 
    """ # TODO: sphinxfy this and implement it correctly
 

	
 
    def __init__(self, path):
 
        log.debug('Loading script %s...' % path)
 
        self.verify(path)
 
        super(BaseScript, self).__init__(path)
 
        log.debug('Script %s loaded successfully' % path)
 
    
 
    @classmethod
 
    def verify(cls, path):
 
        """Ensure this is a valid script
 
        This version simply ensures the script file's existence
 

	
 
        :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
 
        """
 
        try:
 
            cls.require_found(path)
 
        except:
 
            raise exceptions.InvalidScriptError(path)
 

	
 
    def source(self):
 
        """:returns: source code of the script.
 
        :rtype: string
 
        """
 
        fd = open(self.path)
 
        ret = fd.read()
 
        fd.close()
 
        return ret
 

	
 
    def run(self, engine):
 
        """Core of each BaseScript subclass.
 
        This method executes the script.
 
        """
 
        raise NotImplementedError()
rhodecode/lib/dbmigrate/migrate/versioning/script/py.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
import shutil
 
import warnings
 
import logging
 
from StringIO import StringIO
 

	
 
import migrate
 
from migrate.versioning import genmodel, schemadiff
 
from migrate.versioning.config import operations
 
from migrate.versioning.template import Template
 
from migrate.versioning.script import base
 
from migrate.versioning.util import import_path, load_model, with_engine
 
from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
 

	
 
log = logging.getLogger(__name__)
 
__all__ = ['PythonScript']
 

	
 

	
 
class PythonScript(base.BaseScript):
 
    """Base for Python scripts"""
 

	
 
    @classmethod
 
    def create(cls, path, **opts):
 
        """Create an empty migration script at specified path
 
        
 
        :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
 
        cls.require_notfound(path)
 

	
 
        src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
 
        shutil.copy(src, path)
 

	
 
        return cls(path)
 

	
 
    @classmethod
 
    def make_update_script_for_model(cls, engine, oldmodel,
 
                                     model, repository, **opts):
 
        """Create a migration script based on difference between two SA models.
 
        
 
        :param repository: path to migrate repository
 
        :param oldmodel: dotted.module.name:SAClass or SAClass object
 
        :param model: dotted.module.name:SAClass or SAClass object
 
        :param engine: SQLAlchemy engine
 
        :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
 
        :type oldmodel: string or Class
 
        :type model: string or Class
 
        :type engine: Engine instance
 
        :returns: Upgrade / Downgrade script
 
        :rtype: string
 
        """
 
        
 
        if isinstance(repository, basestring):
 
            # oh dear, an import cycle!
 
            from migrate.versioning.repository import Repository
 
            repository = Repository(repository)
 

	
 
        oldmodel = load_model(oldmodel)
 
        model = load_model(model)
 

	
 
        # Compute differences.
 
        diff = schemadiff.getDiffOfModelAgainstModel(
 
            oldmodel,
 
            model,
 
            excludeTables=[repository.version_table])
 
        # TODO: diff can be False (there is no difference?)
 
        decls, upgradeCommands, downgradeCommands = \
 
            genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython()
 

	
 
        # Store differences into file.
 
        src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
 
        f = open(src)
 
        contents = f.read()
 
        f.close()
 

	
 
        # generate source
 
        search = 'def upgrade(migrate_engine):'
 
        contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
 
        if upgradeCommands:
 
            contents = contents.replace('    pass', upgradeCommands, 1)
 
        if downgradeCommands:
 
            contents = contents.replace('    pass', downgradeCommands, 1)
 
        return contents
 

	
 
    @classmethod
 
    def verify_module(cls, path):
 
        """Ensure path is a valid script
 
        
 
        :param path: Script location
 
        :type path: string
 
        :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
 
        :returns: Python module
 
        """
 
        # Try to import and get the upgrade() func
 
        module = import_path(path)
 
        try:
 
            assert callable(module.upgrade)
 
        except Exception, e:
 
            raise InvalidScriptError(path + ': %s' % str(e))
 
        return module
 

	
 
    def preview_sql(self, url, step, **args):
 
        """Mocks SQLAlchemy Engine to store all executed calls in a string 
 
        and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
 

	
 
        :returns: SQL file
 
        """
 
        buf = StringIO()
 
        args['engine_arg_strategy'] = 'mock'
 
        args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
 

	
 
        @with_engine
 
        def go(url, step, **kw):
 
            engine = kw.pop('engine')
 
            self.run(engine, step)
 
            return buf.getvalue()
 

	
 
        return go(url, step, **args)
 

	
 
    def run(self, engine, step):
 
        """Core method of Script file. 
 
        Executes the script's :func:`upgrade` or :func:`downgrade` function
 

	
 
        :param engine: SQLAlchemy Engine
 
        :param step: Operation to run
 
        :type engine: Engine instance
 
        :type step: int
 
        """
 
        if step > 0:
 
            op = 'upgrade'
 
        elif step < 0:
 
            op = 'downgrade'
 
        else:
 
            raise ScriptError("%d is not a valid step" % step)
 

	
 
        funcname = base.operations[op]
 
        script_func = self._func(funcname)
 

	
 
        try:
 
            script_func(engine)
 
        except TypeError:
 
            warnings.warn("upgrade/downgrade functions must accept engine"
 
                " parameter (since version > 0.5.4)", MigrateDeprecationWarning)
 
            raise
 

	
 
    @property
 
    def module(self):
 
        """Calls :meth:`migrate.versioning.script.py.verify_module`
 
        and returns it.
 
        """
 
        if not hasattr(self, '_module'):
 
            self._module = self.verify_module(self.path)
 
        return self._module
 

	
 
    def _func(self, funcname):
 
        if not hasattr(self.module, funcname):
 
            msg = "Function '%s' is not defined in this script"
 
            raise ScriptError(msg % funcname)
 
        return getattr(self.module, funcname)
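A small sketch of the step semantics in PythonScript.run; the script path and engine below are assumed for illustration:

    from migrate.versioning.script import PythonScript

    script = PythonScript('versions/001_initial_release.py')   # assumed path
    script.run(engine, 1)     # step > 0 calls the script's upgrade(migrate_engine)
    script.run(engine, -1)    # step < 0 calls the script's downgrade(migrate_engine)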
rhodecode/lib/dbmigrate/migrate/versioning/script/sql.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 
import logging
 
import shutil
 

	
 
from migrate.versioning.script import base
 
from migrate.versioning.template import Template
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class SqlScript(base.BaseScript):
 
    """A file containing plain SQL statements."""
 

	
 
    @classmethod
 
    def create(cls, path, **opts):
 
        """Create an empty migration script at specified path
 
        
 
        :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
 
        cls.require_notfound(path)
 

	
 
        src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
 
        shutil.copy(src, path)
 
        return cls(path)
 

	
 
    # TODO: why is step parameter even here?
 
    def run(self, engine, step=None, executemany=True):
 
        """Runs SQL script through raw dbapi execute call"""
 
        text = self.source()
 
        # Don't rely on SA's autocommit here
 
        # (SA uses .startswith to check if a commit is needed. What if script
 
        # starts with a comment?)
 
        conn = engine.connect()
 
        try:
 
            trans = conn.begin()
 
            try:
 
                # HACK: SQLite doesn't allow multiple statements through
 
                # its execute() method, but it provides executescript() instead
 
                dbapi = conn.engine.raw_connection()
 
                if executemany and getattr(dbapi, 'executescript', None):
 
                    dbapi.executescript(text)
 
                else:
 
                    conn.execute(text)
 
                trans.commit()
 
            except:
 
                trans.rollback()
 
                raise
 
        finally:
 
            conn.close()
rhodecode/lib/dbmigrate/migrate/versioning/shell.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
"""The migrate command-line tool."""
 

	
 
import sys
 
import inspect
 
import logging
 
from optparse import OptionParser, BadOptionError
 

	
 
from migrate import exceptions
 
from migrate.versioning import api
 
from migrate.versioning.config import *
 
from migrate.versioning.util import asbool
 

	
 

	
 
alias = dict(
 
    s=api.script,
 
    vc=api.version_control,
 
    dbv=api.db_version,
 
    v=api.version,
 
)
 

	
 
def alias_setup():
 
    global alias
 
    for key, val in alias.iteritems():
 
        setattr(api, key, val)
 
alias_setup()
 

	
 

	
 
class PassiveOptionParser(OptionParser):
 

	
 
    def _process_args(self, largs, rargs, values):
 
        """little hack to support all --some_option=value parameters"""
 

	
 
        while rargs:
 
            arg = rargs[0]
 
            if arg == "--":
 
                del rargs[0]
 
                return
 
            elif arg[0:2] == "--":
 
                # if parser does not know about the option
 
                # pass it along (make it anonymous)
 
                try:
 
                    opt = arg.split('=', 1)[0]
 
                    self._match_long_opt(opt)
 
                except BadOptionError:
 
                    largs.append(arg)
 
                    del rargs[0]
 
                else:
 
                    self._process_long_opt(rargs, values)
 
            elif arg[:1] == "-" and len(arg) > 1:
 
                self._process_short_opts(rargs, values)
 
            elif self.allow_interspersed_args:
 
                largs.append(arg)
 
                del rargs[0]
 

	
 
def main(argv=None, **kwargs):
 
    """Shell interface to :mod:`migrate.versioning.api`.
 

	
 
    kwargs are default options that can be overridden by passing
 
    --some_option as a command-line option
 

	
 
    :param disable_logging: if True, skip migrate's logging configuration
 
    :type disable_logging: bool
 
    """
 
    if argv is not None:
 
        argv = argv
 
    else:
 
        argv = list(sys.argv[1:])
 
    commands = list(api.__all__)
 
    commands.sort()
 

	
 
    usage = """%%prog COMMAND ...
 

	
 
    Available commands:
 
        %s
 

	
 
    Enter "%%prog help COMMAND" for information on a particular command.
 
    """ % '\n\t'.join(["%s - %s" % (command.ljust(28),
 
                    api.command_desc.get(command)) for command in commands])
 

	
 
    parser = PassiveOptionParser(usage=usage)
 
    parser.add_option("-d", "--debug",
 
                     action="store_true",
 
                     dest="debug",
 
                     default=False,
 
                     help="Shortcut to turn on DEBUG mode for logging")
 
    parser.add_option("-q", "--disable_logging",
 
                      action="store_true",
 
                      dest="disable_logging",
 
                      default=False,
 
                      help="Use this option to disable logging configuration")
 
    help_commands = ['help', '-h', '--help']
 
    HELP = False
 

	
 
    try:
 
        command = argv.pop(0)
 
        if command in help_commands:
 
            HELP = True
 
            command = argv.pop(0)
 
    except IndexError:
 
        parser.print_help()
 
        return
 

	
 
    command_func = getattr(api, command, None)
 
    if command_func is None or command.startswith('_'):
 
        parser.error("Invalid command %s" % command)
 

	
 
    parser.set_usage(inspect.getdoc(command_func))
 
    f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
 
    for arg in f_args:
 
        parser.add_option(
 
            "--%s" % arg,
 
            dest=arg,
 
            action='store',
 
            type="string")
 

	
 
    # display help of the current command
 
    if HELP:
 
        parser.print_help()
 
        return
 

	
 
    options, args = parser.parse_args(argv)
 

	
 
    # override kwargs with anonymous parameters
 
    override_kwargs = dict()
 
    for arg in list(args):
 
        if arg.startswith('--'):
 
            args.remove(arg)
 
            if '=' in arg:
 
                opt, value = arg[2:].split('=', 1)
 
            else:
 
                opt = arg[2:]
 
                value = True
 
            override_kwargs[opt] = value
 

	
 
    # override kwargs with options if user is overwriting
 
    for key, value in options.__dict__.iteritems():
 
        if value is not None:
 
            override_kwargs[key] = value
 

	
 
    # arguments that function accepts without passed kwargs
 
    f_required = list(f_args)
 
    candidates = dict(kwargs)
 
    candidates.update(override_kwargs)
 
    for key, value in candidates.iteritems():
 
        if key in f_args:
 
            f_required.remove(key)
 

	
 
    # map function arguments to parsed arguments
 
    for arg in args:
 
        try:
 
            kw = f_required.pop(0)
 
        except IndexError:
 
            parser.error("Too many arguments for command %s: %s" % (command,
 
                                                                    arg))
 
        kwargs[kw] = arg
 

	
 
    # apply overrides
 
    kwargs.update(override_kwargs)
 

	
 
    # configure options
 
    for key, value in options.__dict__.iteritems():
 
        kwargs.setdefault(key, value)
 

	
 
    # configure logging
 
    if not asbool(kwargs.pop('disable_logging', False)):
 
        # filter to send <= INFO to stdout and the rest to stderr
 
        class SingleLevelFilter(logging.Filter):
 
            def __init__(self, min=None, max=None):
 
                self.min = min or 0
 
                self.max = max or 100
 

	
 
            def filter(self, record):
 
                return self.min <= record.levelno <= self.max
 

	
 
        logger = logging.getLogger()
 
        h1 = logging.StreamHandler(sys.stdout)
 
        f1 = SingleLevelFilter(max=logging.INFO)
 
        h1.addFilter(f1)
 
        h2 = logging.StreamHandler(sys.stderr)
 
        f2 = SingleLevelFilter(min=logging.WARN)
 
        h2.addFilter(f2)
 
        logger.addHandler(h1)
 
        logger.addHandler(h2)
 

	
 
        if options.debug:
 
            logger.setLevel(logging.DEBUG)
 
        else:
 
            logger.setLevel(logging.INFO)
 

	
 
    log = logging.getLogger(__name__)
 

	
 
    # check if all args are given
 
    try:
 
        num_defaults = len(f_defaults)
 
    except TypeError:
 
        num_defaults = 0
 
    f_args_default = f_args[len(f_args) - num_defaults:]
 
    required = list(set(f_required) - set(f_args_default))
 
    if required:
 
        parser.error("Not enough arguments for command %s: %s not specified" \
 
            % (command, ', '.join(required)))
 

	
 
    # handle command
 
    try:
 
        ret = command_func(**kwargs)
 
        if ret is not None:
 
            log.info(ret)
 
    except (exceptions.UsageError, exceptions.KnownError), e:
 
        parser.error(e.args[0])
 

	
 
if __name__ == "__main__":
 
    main()
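For illustration, the shell entry point can also be driven programmatically; the database URL and repository path below are assumed:

    from migrate.versioning.shell import main

    # equivalent to: migrate upgrade sqlite:///app.db ./repo --debug
    main(argv=['upgrade', 'sqlite:///app.db', './repo', '--debug'])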
rhodecode/lib/dbmigrate/migrate/versioning/template.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
import os
 
import shutil
 
import sys
 

	
 
from pkg_resources import resource_filename
 

	
 
from migrate.versioning.config import *
 
from migrate.versioning import pathed
 

	
 

	
 
class Collection(pathed.Pathed):
 
    """A collection of templates of a specific type"""
 
    _mask = None
 

	
 
    def get_path(self, file):
 
        return os.path.join(self.path, str(file))
 

	
 

	
 
class RepositoryCollection(Collection):
 
    _mask = '%s'
 

	
 
class ScriptCollection(Collection):
 
    _mask = '%s.py_tmpl'
 

	
 
class ManageCollection(Collection):
 
    _mask = '%s.py_tmpl'
 

	
 
class SQLScriptCollection(Collection):
 
    _mask = '%s.py_tmpl'
 

	
 
class Template(pathed.Pathed):
 
    """Finds the paths/packages of various Migrate templates.
 
    
 
    :param path: Templates are loaded from migrate package
 
    if `path` is not provided.
 
    """
 
    pkg = 'migrate.versioning.templates'
 
    _manage = 'manage.py_tmpl'
 

	
 
    def __new__(cls, path=None):
 
        if path is None:
 
            path = cls._find_path(cls.pkg)
 
        return super(Template, cls).__new__(cls, path)
 

	
 
    def __init__(self, path=None):
 
        if path is None:
 
            path = Template._find_path(self.pkg)
 
        super(Template, self).__init__(path)
 
        self.repository = RepositoryCollection(os.path.join(path, 'repository'))
 
        self.script = ScriptCollection(os.path.join(path, 'script'))
 
        self.manage = ManageCollection(os.path.join(path, 'manage'))
 
        self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script'))
 

	
 
    @classmethod
 
    def _find_path(cls, pkg):
 
        """Returns absolute path to dotted python package."""
 
        tmp_pkg = pkg.rsplit('.', 1)
 

	
 
        if len(tmp_pkg) != 1:
 
            return resource_filename(tmp_pkg[0], tmp_pkg[1])
 
        else:
 
            return resource_filename(tmp_pkg[0], '')
 

	
 
    def _get_item(self, collection, theme=None):
 
        """Locates and returns collection.
 
        
 
        :param collection: name of collection to locate
 
        :param theme: name of the theme subfolder in the collection (defaults to "default")
 
        :returns: path to the located template file
 
        :rtype: str
 
        """
 
        item = getattr(self, collection)
 
        theme_mask = getattr(item, '_mask')
 
        theme = theme_mask % (theme or 'default')
 
        return item.get_path(theme)
 

	
 
    def get_repository(self, *a, **kw):
 
        """Calls self._get_item('repository', *a, **kw)"""
 
        return self._get_item('repository', *a, **kw)
 
    
 
    def get_script(self, *a, **kw):
 
        """Calls self._get_item('script', *a, **kw)"""
 
        return self._get_item('script', *a, **kw)
 

	
 
    def get_sql_script(self, *a, **kw):
 
        """Calls self._get_item('sql_script', *a, **kw)"""
 
        return self._get_item('sql_script', *a, **kw)
 

	
 
    def get_manage(self, *a, **kw):
 
        """Calls self._get_item('manage', *a, **kw)"""
 
        return self._get_item('manage', *a, **kw)
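A quick sketch of how the theme masks above resolve to template files; the printed paths are indicative only:

    from migrate.versioning.template import Template

    tmpl = Template()                       # defaults to the migrate.versioning.templates package
    print tmpl.get_script()                 # .../templates/script/default.py_tmpl
    print tmpl.get_manage(theme='pylons')   # .../templates/manage/pylons.py_tmpl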
rhodecode/lib/dbmigrate/migrate/versioning/templates/__init__.py
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/manage.py_tmpl
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
from migrate.versioning.shell import main
 

	
 
if __name__ == '__main__':
 
    main(%(defaults)s)
rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/default.py_tmpl
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
from migrate.versioning.shell import main
 

	
 
{{py:
 
_vars = locals().copy()
 
del _vars['__template_name__']
 
_vars.pop('repository_name', None)
 
defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()])
 
}}
 
main({{ defaults }})
rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/pylons.py_tmpl
Show inline comments
 
new file 100644
 
#!/usr/bin/python
 
# -*- coding: utf-8 -*-
 
import sys
 

	
 
from sqlalchemy import engine_from_config
 
from paste.deploy.loadwsgi import ConfigLoader
 

	
 
from migrate.versioning.shell import main
 
from {{ locals().pop('repository_name') }}.model import migrations
 

	
 

	
 
if '-c' in sys.argv:
 
    pos = sys.argv.index('-c')
 
    conf_path = sys.argv[pos + 1]
 
    del sys.argv[pos:pos + 2]
 
else:
 
    conf_path = 'development.ini'
 

	
 
{{py:
 
_vars = locals().copy()
 
del _vars['__template_name__']
 
defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()])
 
}}
 

	
 
conf_dict = ConfigLoader(conf_path).parser._sections['app:main']
 

	
 
# migrate supports passing url as an existing Engine instance (since 0.6.0)
 
# usage: migrate -c path/to/config.ini COMMANDS
 
main(url=engine_from_config(conf_dict), repository=migrations.__path__[0],{{ defaults }})
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/__init__.py
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/README
Show inline comments
 
new file 100644
 
This is a database migration repository.
 

	
 
More information at
 
http://code.google.com/p/sqlalchemy-migrate/
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/__init__.py
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/migrate.cfg
Show inline comments
 
new file 100644
 
[db_settings]
 
# Used to identify which repository this database is versioned under.
 
# You can use the name of your project.
 
repository_id={{ locals().pop('repository_id') }}
 

	
 
# The name of the database table used to track the schema version.
 
# This name shouldn't already be used by your project.
 
# If this is changed once a database is under version control, you'll need to 
 
# change the table name in each database too. 
 
version_table={{ locals().pop('version_table') }}
 

	
 
# When committing a change script, Migrate will attempt to generate the 
 
# sql for all supported databases; normally, if one of them fails - probably
 
# because you don't have that database installed - it is ignored and the 
 
# commit continues, perhaps ending successfully. 
 
# Databases in this list MUST compile successfully during a commit, or the 
 
# entire commit will fail. List the databases your application will actually 
 
# be using to ensure your updates to that database work properly.
 
# This must be a list; example: ['postgres','sqlite']
 
required_dbs={{ locals().pop('required_dbs') }}
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/versions/__init__.py
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/README
Show inline comments
 
new file 100644
 
This is a database migration repository.
 

	
 
More information at
 
http://code.google.com/p/sqlalchemy-migrate/
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/__init__.py
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/migrate.cfg
Show inline comments
 
new file 100644
 
[db_settings]
 
# Used to identify which repository this database is versioned under.
 
# You can use the name of your project.
 
repository_id={{ locals().pop('repository_id') }}
 

	
 
# The name of the database table used to track the schema version.
 
# This name shouldn't already be used by your project.
 
# If this is changed once a database is under version control, you'll need to 
 
# change the table name in each database too. 
 
version_table={{ locals().pop('version_table') }}
 

	
 
# When committing a change script, Migrate will attempt to generate the 
 
# sql for all supported databases; normally, if one of them fails - probably
 
# because you don't have that database installed - it is ignored and the 
 
# commit continues, perhaps ending successfully. 
 
# Databases in this list MUST compile successfully during a commit, or the 
 
# entire commit will fail. List the databases your application will actually 
 
# be using to ensure your updates to that database work properly.
 
# This must be a list; example: ['postgres','sqlite']
 
required_dbs={{ locals().pop('required_dbs') }}
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/versions/__init__.py
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/script/__init__.py
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/script/default.py_tmpl
Show inline comments
 
new file 100644
 
from sqlalchemy import *
 
from migrate import *
 

	
 
def upgrade(migrate_engine):
 
    # Upgrade operations go here. Don't create your own engine; bind migrate_engine
 
    # to your metadata
 
    pass
 

	
 
def downgrade(migrate_engine):
 
    # Operations to reverse the above upgrade go here.
 
    pass
rhodecode/lib/dbmigrate/migrate/versioning/templates/script/pylons.py_tmpl
Show inline comments
 
new file 100644
 
from sqlalchemy import *
 
from migrate import *
 

	
 
def upgrade(migrate_engine):
 
    # Upgrade operations go here. Don't create your own engine; bind migrate_engine
 
    # to your metadata
 
    pass
 

	
 
def downgrade(migrate_engine):
 
    # Operations to reverse the above upgrade go here.
 
    pass
rhodecode/lib/dbmigrate/migrate/versioning/templates/sql_script/default.py_tmpl
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/templates/sql_script/pylons.py_tmpl
Show inline comments
 
new file 100644
rhodecode/lib/dbmigrate/migrate/versioning/util/__init__.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 
""".. currentmodule:: migrate.versioning.util"""
 

	
 
import warnings
 
import logging
 
from decorator import decorator
 
from pkg_resources import EntryPoint
 

	
 
from sqlalchemy import create_engine
 
from sqlalchemy.engine import Engine
 
from sqlalchemy.pool import StaticPool
 

	
 
from migrate import exceptions
 
from migrate.versioning.util.keyedinstance import KeyedInstance
 
from migrate.versioning.util.importpath import import_path
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
def load_model(dotted_name):
 
    """Import a module and return the named module-level variable.
 

	
 
    :param dotted_name: path to model in form of string: ``some.python.module:Class``
 
    
 
    .. versionchanged:: 0.5.4
 

	
 
    """
 
    if isinstance(dotted_name, basestring):
 
        if ':' not in dotted_name:
 
            # backwards compatibility
 
            warnings.warn('model should be in form of module.model:User '
 
                'and not module.model.User', exceptions.MigrateDeprecationWarning)
 
            dotted_name = ':'.join(dotted_name.rsplit('.', 1))
 
        return EntryPoint.parse('x=%s' % dotted_name).load(False)
 
    else:
 
        # Assume it's already loaded.
 
        return dotted_name
 

	
 
def asbool(obj):
 
    """Do everything to use object as bool"""
 
    if isinstance(obj, basestring):
 
        obj = obj.strip().lower()
 
        if obj in ['true', 'yes', 'on', 'y', 't', '1']:
 
            return True
 
        elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
 
            return False
 
        else:
 
            raise ValueError("String is not true/false: %r" % obj)
 
    if obj in (True, False):
 
        return bool(obj)
 
    else:
 
        raise ValueError("Cannot interpret value as a boolean: %r" % obj)
 

	
 
def guess_obj_type(obj):
 
    """Do everything to guess object type from string
 
    
 
    Tries to convert to `int`, then `bool`, and returns the original object if neither succeeds.
 
    
 
    .. versionadded: 0.5.4
 
    """
 

	
 
    result = None
 

	
 
    try:
 
        result = int(obj)
 
    except:
 
        pass
 

	
 
    if result is None:
 
        try:
 
            result = asbool(obj)
 
        except:
 
            pass
 

	
 
    if result is not None:
 
        return result
 
    else:
 
        return obj
 

	
 
@decorator
 
def catch_known_errors(f, *a, **kw):
 
    """Decorator that catches known api errors
 
    
 
    .. versionadded: 0.5.4
 
    """
 

	
 
    try:
 
        return f(*a, **kw)
 
    except exceptions.PathFoundError, e:
 
        raise exceptions.KnownError("The path %s already exists" % e.args[0])
 

	
 
def construct_engine(engine, **opts):
 
    """.. versionadded:: 0.5.4
 

	
 
    Constructs and returns SQLAlchemy engine.
 

	
 
    Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
 

	
 
    :param engine: connection string or an existing engine
 
    :param engine_dict: python dictionary of options to pass to `create_engine`
 
    :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
 
    :type engine_dict: dict
 
    :type engine: string or Engine instance
 
    :type engine_arg_*: string
 
    :returns: SQLAlchemy Engine
 

	
 
    .. note::
 

	
 
        keyword parameters override ``engine_dict`` values.
 

	
 
    """
 
    if isinstance(engine, Engine):
 
        return engine
 
    elif not isinstance(engine, basestring):
 
        raise ValueError("you need to pass either an existing engine or a database uri")
 

	
 
    # get options for create_engine
 
    if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
 
        kwargs = opts['engine_dict']
 
    else:
 
        kwargs = dict()
 

	
 
    # DEPRECATED: handle echo the old way
 
    echo = asbool(opts.get('echo', False))
 
    if echo:
 
        warnings.warn('echo=True parameter is deprecated, pass '
 
            'engine_arg_echo=True or engine_dict={"echo": True}',
 
            exceptions.MigrateDeprecationWarning)
 
        kwargs['echo'] = echo
 

	
 
    # parse keyword arguments
 
    for key, value in opts.iteritems():
 
        if key.startswith('engine_arg_'):
 
            kwargs[key[11:]] = guess_obj_type(value)
 

	
 
    log.debug('Constructing engine')
 
    # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
 
    # seems like 0.5.x branch does not work with engine.dispose and staticpool
 
    return create_engine(engine, **kwargs)
 

	
 
@decorator
 
def with_engine(f, *a, **kw):
 
    """Decorator for :mod:`migrate.versioning.api` functions
 
    to safely close resources after function usage.
 

	
 
    Passes engine parameters to :func:`construct_engine` and
 
    the resulting engine is available as kw['engine'].
 

	
 
    Engine is disposed after wrapped function is executed.
 

	
 
    .. versionadded: 0.6.0
 
    """
 
    url = a[0]
 
    engine = construct_engine(url, **kw)
 

	
 
    try:
 
        kw['engine'] = engine
 
        return f(*a, **kw)
 
    finally:
 
        if isinstance(engine, Engine):
 
            log.debug('Disposing SQLAlchemy engine %s', engine)
 
            engine.dispose()
 

	
 

	
 
class Memoize:
 
    """Memoize(fn) - an instance which acts like fn but memoizes its arguments
 
       Will only work on functions with non-mutable arguments
 

	
 
       ActiveState Code 52201
 
    """
 
    def __init__(self, fn):
 
        self.fn = fn
 
        self.memo = {}
 

	
 
    def __call__(self, *args):
 
        if not self.memo.has_key(args):
 
            self.memo[args] = self.fn(*args)
 
        return self.memo[args]
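A brief sketch of the two ways construct_engine accepts create_engine options; the URL and option values are assumed:

    from migrate.versioning.util import construct_engine

    engine = construct_engine(
        'sqlite:///app.db',
        engine_dict={'echo': True},            # passed straight to create_engine
        engine_arg_pool_recycle='3600')        # string coerced via guess_obj_type to int 3600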
rhodecode/lib/dbmigrate/migrate/versioning/util/importpath.py
Show inline comments
 
new file 100644
 
import os
 
import sys
 

	
 
def import_path(fullpath):
 
    """ Import a file with full path specification. Allows one to
 
        import from anywhere, something __import__ does not do. 
 
    """
 
    # http://zephyrfalcon.org/weblog/arch_d7_2002_08_31.html
 
    path, filename = os.path.split(fullpath)
 
    filename, ext = os.path.splitext(filename)
 
    sys.path.append(path)
 
    module = __import__(filename)
 
    reload(module) # Might be out of date during tests
 
    del sys.path[-1]
 
    return module
 

	
rhodecode/lib/dbmigrate/migrate/versioning/util/keyedinstance.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
class KeyedInstance(object):
 
    """A class whose instances have a unique identifier of some sort
 
    No two instances with the same unique ID should exist - if we try to create
 
    a second instance, the first should be returned. 
 
    """
 

	
 
    _instances = dict()
 

	
 
    def __new__(cls, *p, **k):
 
        instances = cls._instances
 
        clskey = str(cls)
 
        if clskey not in instances:
 
            instances[clskey] = dict()
 
        instances = instances[clskey]
 

	
 
        key = cls._key(*p, **k)
 
        if key not in instances:
 
            instances[key] = super(KeyedInstance, cls).__new__(cls)
 
        return instances[key]
 

	
 
    @classmethod
 
    def _key(cls, *p, **k):
 
        """Given a unique identifier, return a dictionary key
 
        This should be overridden by child classes, to specify which parameters 
 
        should determine an object's uniqueness
 
        """
 
        raise NotImplementedError()
 

	
 
    @classmethod
 
    def clear(cls):
 
        # Allow cls.clear() as well as uniqueInstance.clear(cls)
 
        if str(cls) in cls._instances:
 
            del cls._instances[str(cls)]
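A minimal sketch of a KeyedInstance subclass; the Widget class is hypothetical and used purely for illustration:

    from migrate.versioning.util.keyedinstance import KeyedInstance

    class Widget(KeyedInstance):
        """Instances are unique per name (hypothetical example)."""

        @classmethod
        def _key(cls, name):
            return str(name)

        def __init__(self, name):
            self.name = name

    assert Widget('spam') is Widget('spam')   # same key -> the cached instance is returned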
rhodecode/lib/dbmigrate/migrate/versioning/version.py
Show inline comments
 
new file 100644
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
import os
 
import re
 
import shutil
 
import logging
 

	
 
from migrate import exceptions
 
from migrate.versioning import pathed, script
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class VerNum(object):
 
    """A version number that behaves like a string and int at the same time"""
 

	
 
    _instances = dict()
 

	
 
    def __new__(cls, value):
 
        val = str(value)
 
        if val not in cls._instances:
 
            cls._instances[val] = super(VerNum, cls).__new__(cls)
 
        ret = cls._instances[val]
 
        return ret
 

	
 
    def __init__(self,value):
 
        self.value = str(int(value))
 
        if self < 0:
 
            raise ValueError("Version number cannot be negative")
 

	
 
    def __add__(self, value):
 
        ret = int(self) + int(value)
 
        return VerNum(ret)
 

	
 
    def __sub__(self, value):
 
        return self + (int(value) * -1)
 

	
 
    def __cmp__(self, value):
 
        return int(self) - int(value)
 

	
 
    def __repr__(self):
 
        return "<VerNum(%s)>" % self.value
 

	
 
    def __str__(self):
 
        return str(self.value)
 

	
 
    def __int__(self):
 
        return int(self.value)
 

	
 

	
 
class Collection(pathed.Pathed):
 
    """A collection of versioning scripts in a repository"""
 

	
 
    FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
 

	
 
    def __init__(self, path):
 
        """Collect current version scripts in repository
 
        and store them in self.versions
 
        """
 
        super(Collection, self).__init__(path)
 
        
 
        # Create temporary list of files, allowing skipped version numbers.
 
        files = os.listdir(path)
 
        if '1' in files:
 
            # deprecation
 
            raise Exception('It looks like you have a repository in the old '
 
                'format (with directories for each version). '
 
                'Please convert repository before proceeding.')
 

	
 
        tempVersions = dict()
 
        for filename in files:
 
            match = self.FILENAME_WITH_VERSION.match(filename)
 
            if match:
 
                num = int(match.group(1))
 
                tempVersions.setdefault(num, []).append(filename)
 
            else:
 
                pass  # Must be a helper file or something, let's ignore it.
 

	
 
        # Create the versions member where the keys
 
        # are VerNum's and the values are Version's.
 
        self.versions = dict()
 
        for num, files in tempVersions.items():
 
            self.versions[VerNum(num)] = Version(num, path, files)
 

	
 
    @property
 
    def latest(self):
 
        """:returns: Latest version in Collection"""
 
        return max([VerNum(0)] + self.versions.keys())
 

	
 
    def create_new_python_version(self, description, **k):
 
        """Create Python files for new version"""
 
        ver = self.latest + 1
 
        extra = str_to_filename(description)
 

	
 
        if extra:
 
            if extra == '_':
 
                extra = ''
 
            elif not extra.startswith('_'):
 
                extra = '_%s' % extra
 

	
 
        filename = '%03d%s.py' % (ver, extra)
 
        filepath = self._version_path(filename)
 

	
 
        script.PythonScript.create(filepath, **k)
 
        self.versions[ver] = Version(ver, self.path, [filename])
 
        
 
    def create_new_sql_version(self, database, **k):
 
        """Create SQL files for new version"""
 
        ver = self.latest + 1
 
        self.versions[ver] = Version(ver, self.path, [])
 

	
 
        # Create new files.
 
        for op in ('upgrade', 'downgrade'):
 
            filename = '%03d_%s_%s.sql' % (ver, database, op)
 
            filepath = self._version_path(filename)
 
            script.SqlScript.create(filepath, **k)
 
            self.versions[ver].add_script(filepath)
 
        
 
    def version(self, vernum=None):
 
        """Returns latest Version if vernum is not given.
 
        Otherwise, returns wanted version"""
 
        if vernum is None:
 
            vernum = self.latest
 
        return self.versions[VerNum(vernum)]
 

	
 
    @classmethod
 
    def clear(cls):
 
        super(Collection, cls).clear()
 

	
 
    def _version_path(self, ver):
 
        """Returns path of file in versions repository"""
 
        return os.path.join(self.path, str(ver))
 

	
 

	
 
class Version(object):
 
    """A single version in a collection
 
    :param vernum: Version Number 
 
    :param path: Path to script files
 
    :param filelist: List of scripts
 
    :type vernum: int, VerNum
 
    :type path: string
 
    :type filelist: list
 
    """
 

	
 
    def __init__(self, vernum, path, filelist):
 
        self.version = VerNum(vernum)
 

	
 
        # Collect scripts in this folder
 
        self.sql = dict()
 
        self.python = None
 

	
 
        for script in filelist:
 
            self.add_script(os.path.join(path, script))
 
    
 
    def script(self, database=None, operation=None):
 
        """Returns SQL or Python Script"""
 
        for db in (database, 'default'):
 
            # Try to return a .sql script first
 
            try:
 
                return self.sql[db][operation]
 
            except KeyError:
 
                continue  # No .sql script exists
 

	
 
        # TODO: maybe add force Python parameter?
 
        ret = self.python
 

	
 
        assert ret is not None, \
 
            "There is no script for %d version" % self.version
 
        return ret
 

	
 
    def add_script(self, path):
 
        """Add script to Collection/Version"""
 
        if path.endswith(Extensions.py):
 
            self._add_script_py(path)
 
        elif path.endswith(Extensions.sql):
 
            self._add_script_sql(path)
 

	
 
    SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+)\.sql')
 

	
 
    def _add_script_sql(self, path):
 
        basename = os.path.basename(path)
 
        match = self.SQL_FILENAME.match(basename)
 

	
 
        if match:
 
            version, dbms, op = match.group(1), match.group(2), match.group(3)
 
        else:
 
            raise exceptions.ScriptError(
 
                "Invalid SQL script name %s " % basename + \
 
                "(needs to be ###_database_operation.sql)")
 

	
 
        # File the script into a dictionary
 
        self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
 

	
 
    def _add_script_py(self, path):
 
        if self.python is not None:
 
            raise exceptions.ScriptError('You can only have one Python script '
 
                'per version, but you have: %s and %s' % (self.python, path))
 
        self.python = script.PythonScript(path)
 

	
 

	
 
class Extensions:
 
    """A namespace for file extensions"""
 
    py = 'py'
 
    sql = 'sql'
 

	
 
def str_to_filename(s):
 
    """Replaces spaces, (double and single) quotes
 
    and double underscores to underscores
 
    """
 

	
 
    s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
 
    while '__' in s:
 
        s = s.replace('__', '_')
 
    return s
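
A minimal usage sketch of the classes above; the repository path, version numbers and the vendored import path rhodecode.lib.dbmigrate.migrate.versioning.version are assumptions for illustration only:

# Hypothetical usage sketch (Python 2, matching the rest of this changeset)
from rhodecode.lib.dbmigrate.migrate.versioning.version import (
    Collection, VerNum, str_to_filename)

coll = Collection('rhodecode/lib/dbmigrate/versions')   # example path
print coll.latest                      # VerNum prints like a string, e.g. "2"
assert VerNum('2') == VerNum(2)        # instances are interned per value
assert coll.latest + 1 > coll.latest   # arithmetic returns a new VerNum
print str_to_filename("Add user_followings table")   # Add_user_followings_table
script = coll.version(coll.latest).script()   # SQL script first, Python fallback
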
rhodecode/lib/dbmigrate/versions/001_initial_release.py
Show inline comments
 
new file 100644
 
from migrate import *
 

	
 
#==============================================================================
 
# DB INITIAL MODEL
 
#==============================================================================
 
import logging
 
import datetime
 

	
 
from sqlalchemy import *
 
from sqlalchemy.exc import DatabaseError
 
from sqlalchemy.orm import relation, backref, class_mapper
 
from sqlalchemy.orm.session import Session
 

	
 
from rhodecode.model.meta import Base
 

	
 
log = logging.getLogger(__name__)
 

	
 
class BaseModel(object):
 

	
 
    @classmethod
 
    def _get_keys(cls):
 
        """return column names for this model """
 
        return class_mapper(cls).c.keys()
 

	
 
    def get_dict(self):
 
        """return dict with keys and values corresponding 
 
        to this model data """
 

	
 
        d = {}
 
        for k in self._get_keys():
 
            d[k] = getattr(self, k)
 
        return d
 

	
 
    def get_appstruct(self):
 
        """return list with keys and values tupples corresponding 
 
        to this model data """
 

	
 
        l = []
 
        for k in self._get_keys():
 
            l.append((k, getattr(self, k),))
 
        return l
 

	
 
    def populate_obj(self, populate_dict):
 
        """populate model with data from given populate_dict"""
 

	
 
        for k in self._get_keys():
 
            if k in populate_dict:
 
                setattr(self, k, populate_dict[k])
 

	
 
class RhodeCodeSettings(Base, BaseModel):
 
    __tablename__ = 'rhodecode_settings'
 
    __table_args__ = (UniqueConstraint('app_settings_name'), {'useexisting':True})
 
    app_settings_id = Column("app_settings_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    app_settings_name = Column("app_settings_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    app_settings_value = Column("app_settings_value", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 

	
 
    def __init__(self, k, v):
 
        self.app_settings_name = k
 
        self.app_settings_value = v
 

	
 
    def __repr__(self):
 
        return "<RhodeCodeSetting('%s:%s')>" % (self.app_settings_name,
 
                                                self.app_settings_value)
 

	
 
class RhodeCodeUi(Base, BaseModel):
 
    __tablename__ = 'rhodecode_ui'
 
    __table_args__ = {'useexisting':True}
 
    ui_id = Column("ui_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    ui_section = Column("ui_section", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    ui_key = Column("ui_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    ui_value = Column("ui_value", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    ui_active = Column("ui_active", Boolean(), nullable=True, unique=None, default=True)
 

	
 

	
 
class User(Base, BaseModel):
 
    __tablename__ = 'users'
 
    __table_args__ = (UniqueConstraint('username'), UniqueConstraint('email'), {'useexisting':True})
 
    user_id = Column("user_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    username = Column("username", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    password = Column("password", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    active = Column("active", Boolean(), nullable=True, unique=None, default=None)
 
    admin = Column("admin", Boolean(), nullable=True, unique=None, default=False)
 
    name = Column("name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    lastname = Column("lastname", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    email = Column("email", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    last_login = Column("last_login", DateTime(timezone=False), nullable=True, unique=None, default=None)
 
    is_ldap = Column("is_ldap", Boolean(), nullable=False, unique=None, default=False)
 

	
 
    user_log = relation('UserLog', cascade='all')
 
    user_perms = relation('UserToPerm', primaryjoin="User.user_id==UserToPerm.user_id", cascade='all')
 

	
 
    repositories = relation('Repository')
 
    user_followers = relation('UserFollowing', primaryjoin='UserFollowing.follows_user_id==User.user_id', cascade='all')
 

	
 
    @property
 
    def full_contact(self):
 
        return '%s %s <%s>' % (self.name, self.lastname, self.email)
 

	
 
    def __repr__(self):
 
        return "<User('id:%s:%s')>" % (self.user_id, self.username)
 

	
 
    def update_lastlogin(self):
 
        """Update user lastlogin"""
 

	
 
        try:
 
            session = Session.object_session(self)
 
            self.last_login = datetime.datetime.now()
 
            session.add(self)
 
            session.commit()
 
            log.debug('updated user %s lastlogin', self.username)
 
        except (DatabaseError,):
 
            session.rollback()
 

	
 

	
 
class UserLog(Base, BaseModel):
 
    __tablename__ = 'user_logs'
 
    __table_args__ = {'useexisting':True}
 
    user_log_id = Column("user_log_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
 
    repository_id = Column("repository_id", Integer(length=None, convert_unicode=False, assert_unicode=None), ForeignKey(u'repositories.repo_id'), nullable=False, unique=None, default=None)
 
    repository_name = Column("repository_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    user_ip = Column("user_ip", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    action = Column("action", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    action_date = Column("action_date", DateTime(timezone=False), nullable=True, unique=None, default=None)
 

	
 
    user = relation('User')
 
    repository = relation('Repository')
 

	
 
class Repository(Base, BaseModel):
 
    __tablename__ = 'repositories'
 
    __table_args__ = (UniqueConstraint('repo_name'), {'useexisting':True},)
 
    repo_id = Column("repo_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    repo_name = Column("repo_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=False, unique=True, default=None)
 
    repo_type = Column("repo_type", String(length=None, convert_unicode=False, assert_unicode=None), nullable=False, unique=False, default=None)
 
    user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=False, default=None)
 
    private = Column("private", Boolean(), nullable=True, unique=None, default=None)
 
    enable_statistics = Column("statistics", Boolean(), nullable=True, unique=None, default=True)
 
    description = Column("description", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    fork_id = Column("fork_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=True, unique=False, default=None)
 

	
 
    user = relation('User')
 
    fork = relation('Repository', remote_side=repo_id)
 
    repo_to_perm = relation('RepoToPerm', cascade='all')
 
    stats = relation('Statistics', cascade='all', uselist=False)
 

	
 
    repo_followers = relation('UserFollowing', primaryjoin='UserFollowing.follows_repo_id==Repository.repo_id', cascade='all')
 

	
 

	
 
    def __repr__(self):
 
        return "<Repository('%s:%s')>" % (self.repo_id, self.repo_name)
 

	
 
class Permission(Base, BaseModel):
 
    __tablename__ = 'permissions'
 
    __table_args__ = {'useexisting':True}
 
    permission_id = Column("permission_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    permission_name = Column("permission_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    permission_longname = Column("permission_longname", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 

	
 
    def __repr__(self):
 
        return "<Permission('%s:%s')>" % (self.permission_id, self.permission_name)
 

	
 
class RepoToPerm(Base, BaseModel):
 
    __tablename__ = 'repo_to_perm'
 
    __table_args__ = (UniqueConstraint('user_id', 'repository_id'), {'useexisting':True})
 
    repo_to_perm_id = Column("repo_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey(u'permissions.permission_id'), nullable=False, unique=None, default=None)
 
    repository_id = Column("repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=False, unique=None, default=None)
 

	
 
    user = relation('User')
 
    permission = relation('Permission')
 
    repository = relation('Repository')
 

	
 
class UserToPerm(Base, BaseModel):
 
    __tablename__ = 'user_to_perm'
 
    __table_args__ = (UniqueConstraint('user_id', 'permission_id'), {'useexisting':True})
 
    user_to_perm_id = Column("user_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey(u'permissions.permission_id'), nullable=False, unique=None, default=None)
 

	
 
    user = relation('User')
 
    permission = relation('Permission')
 

	
 
class Statistics(Base, BaseModel):
 
    __tablename__ = 'statistics'
 
    __table_args__ = (UniqueConstraint('repository_id'), {'useexisting':True})
 
    stat_id = Column("stat_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    repository_id = Column("repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=False, unique=True, default=None)
 
    stat_on_revision = Column("stat_on_revision", Integer(), nullable=False)
 
    commit_activity = Column("commit_activity", LargeBinary(), nullable=False)#JSON data
 
    commit_activity_combined = Column("commit_activity_combined", LargeBinary(), nullable=False)#JSON data
 
    languages = Column("languages", LargeBinary(), nullable=False)#JSON data
 

	
 
    repository = relation('Repository', single_parent=True)
 

	
 
class UserFollowing(Base, BaseModel):
 
    __tablename__ = 'user_followings'
 
    __table_args__ = (UniqueConstraint('user_id', 'follows_repository_id'),
 
                      UniqueConstraint('user_id', 'follows_user_id')
 
                      , {'useexisting':True})
 

	
 
    user_following_id = Column("user_following_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
 
    follows_repo_id = Column("follows_repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=True, unique=None, default=None)
 
    follows_user_id = Column("follows_user_id", Integer(), ForeignKey(u'users.user_id'), nullable=True, unique=None, default=None)
 

	
 
    user = relation('User', primaryjoin='User.user_id==UserFollowing.user_id')
 

	
 
    follows_user = relation('User', primaryjoin='User.user_id==UserFollowing.follows_user_id')
 
    follows_repository = relation('Repository')
 

	
 

	
 
class CacheInvalidation(Base, BaseModel):
 
    __tablename__ = 'cache_invalidation'
 
    __table_args__ = (UniqueConstraint('cache_key'), {'useexisting':True})
 
    cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    cache_key = Column("cache_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    cache_args = Column("cache_args", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)
 

	
 

	
 
    def __init__(self, cache_key, cache_args=''):
 
        self.cache_key = cache_key
 
        self.cache_args = cache_args
 
        self.cache_active = False
 

	
 
    def __repr__(self):
 
        return "<CacheInvalidation('%s:%s')>" % (self.cache_id, self.cache_key)
 

	
 

	
 
def upgrade(migrate_engine):
 
    # Upgrade operations go here. Don't create your own engine; bind migrate_engine
 
    # to your metadata
 
    Base.metadata.create_all(bind=migrate_engine, checkfirst=False)
 

	
 
def downgrade(migrate_engine):
 
    # Operations to reverse the above upgrade go here.
 
    Base.metadata.drop_all(bind=migrate_engine, checkfirst=False)
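
A rough sketch of how this script is exercised (the SQLite URL and the file path are examples only); SQLAlchemy-migrate hands upgrade()/downgrade() an already-configured engine:

# Example only: drive the 001 script by hand against a scratch database
import imp
from sqlalchemy import create_engine

engine = create_engine('sqlite:///upgrade_test.db')          # example URL
mod = imp.load_source(
    'rc_001', 'rhodecode/lib/dbmigrate/versions/001_initial_release.py')
mod.upgrade(engine)     # creates every table declared on Base.metadata
mod.downgrade(engine)   # drops them all again
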
rhodecode/lib/dbmigrate/versions/002_version_1_1_0.py
Show inline comments
 
new file 100644
 
from sqlalchemy import *
 
from sqlalchemy.orm import relation
 

	
 
from migrate import *
 
from migrate.changeset import *
 
from rhodecode.model.meta import Base, BaseModel
 

	
 
def upgrade(migrate_engine):
 
    """ Upgrade operations go here. 
 
    Don't create your own engine; bind migrate_engine to your metadata
 
    """
 

	
 
    #==========================================================================
 
    # Upgrade of `users` table
 
    #==========================================================================
 
    tblname = 'users'
 
    tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
 
                    autoload_with=migrate_engine)
 

	
 
    #ADD is_ldap column
 
    is_ldap = Column("is_ldap", Boolean(), nullable=False,
 
                     unique=None, default=False)
 
    is_ldap.create(tbl)
 

	
 

	
 
    #==========================================================================
 
    # Upgrade of `user_logs` table
 
    #==========================================================================    
 

	
 
    tblname = 'user_logs'
 
    tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
 
                    autoload_with=migrate_engine)
 

	
 
    #ADD revision column
 
    revision = Column('revision', TEXT(length=None, convert_unicode=False,
 
                                       assert_unicode=None),
 
                      nullable=True, unique=None, default=None)
 
    revision.create(tbl)
 

	
 

	
 

	
 
    #==========================================================================
 
    # Upgrade of `repositories` table
 
    #==========================================================================    
 
    tblname = 'repositories'
 
    tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
 
                    autoload_with=migrate_engine)
 

	
 
    #ADD repo_type column
 
    repo_type = Column("repo_type", String(length=None, convert_unicode=False,
 
                                           assert_unicode=None),
 
                       nullable=False, unique=False, default=None)
 
    repo_type.create(tbl)
 

	
 

	
 
    #ADD statistics column
 
    enable_statistics = Column("statistics", Boolean(), nullable=True,
 
                               unique=None, default=True)
 
    enable_statistics.create(tbl)
 

	
 

	
 

	
 
    #==========================================================================
 
    # Add table `user_followings`
 
    #==========================================================================
 
    tblname = 'user_followings'
 
    class UserFollowing(Base, BaseModel):
 
        __tablename__ = 'user_followings'
 
        __table_args__ = (UniqueConstraint('user_id', 'follows_repository_id'),
 
                          UniqueConstraint('user_id', 'follows_user_id')
 
                          , {'useexisting':True})
 

	
 
        user_following_id = Column("user_following_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
        user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
 
        follows_repo_id = Column("follows_repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=True, unique=None, default=None)
 
        follows_user_id = Column("follows_user_id", Integer(), ForeignKey(u'users.user_id'), nullable=True, unique=None, default=None)
 

	
 
        user = relation('User', primaryjoin='User.user_id==UserFollowing.user_id')
 

	
 
        follows_user = relation('User', primaryjoin='User.user_id==UserFollowing.follows_user_id')
 
        follows_repository = relation('Repository')
 

	
 
    Base.metadata.tables[tblname].create(migrate_engine)
 

	
 
    #==========================================================================
 
    # Add table `cache_invalidation`
 
    #==========================================================================
 
    class CacheInvalidation(Base, BaseModel):
 
        __tablename__ = 'cache_invalidation'
 
        __table_args__ = (UniqueConstraint('cache_key'), {'useexisting':True})
 
        cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
        cache_key = Column("cache_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
        cache_args = Column("cache_args", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
        cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)
 

	
 

	
 
        def __init__(self, cache_key, cache_args=''):
 
            self.cache_key = cache_key
 
            self.cache_args = cache_args
 
            self.cache_active = False
 

	
 
        def __repr__(self):
 
            return "<CacheInvalidation('%s:%s')>" % (self.cache_id, self.cache_key)
 

	
 
    Base.metadata.tables['cache_invalidation'].create(migrate_engine)
 

	
 
    return
 

	
 

	
 

	
 

	
 

	
 

	
 
def downgrade(migrate_engine):
 
    meta = MetaData()
 
    meta.bind = migrate_engine
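    # A possible counterpart (sketch only, names taken from the upgrade above):
    # drop the added columns and tables with the same migrate.changeset API, e.g.
    #
    #   users_tbl = Table('users', meta, autoload=True)
    #   users_tbl.c.is_ldap.drop()
    #   logs_tbl = Table('user_logs', meta, autoload=True)
    #   logs_tbl.c.revision.drop()
    #   repos_tbl = Table('repositories', meta, autoload=True)
    #   repos_tbl.c.repo_type.drop()
    #   repos_tbl.c.statistics.drop()
    #   Table('cache_invalidation', meta, autoload=True).drop(migrate_engine)
    #   Table('user_followings', meta, autoload=True).drop(migrate_engine)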
 

	
 

	
rhodecode/lib/dbmigrate/versions/__init__.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 
"""
 
    rhodecode.lib.dbmigrate.versions.__init__
 
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
    Package containing new versions of database models
 
    
 
    :created_on: Dec 11, 2010
 
    :author: marcink
 
    :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>    
 
    :license: GPLv3, see COPYING for more details.
 
"""
 
# This program is free software; you can redistribute it and/or
 
# modify it under the terms of the GNU General Public License
 
# as published by the Free Software Foundation; version 2
 
# of the License or (at your option) any later version of the license.
 
# 
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
# 
 
# You should have received a copy of the GNU General Public License
 
# along with this program; if not, write to the Free Software
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 
# MA  02110-1301, USA.
rhodecode/lib/utils.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
"""
 
    package.rhodecode.lib.utils
 
    ~~~~~~~~~~~~~~
 
    rhodecode.lib.utils
 
    ~~~~~~~~~~~~~~~~~~~
 

	
 
    Utilities library for RhodeCode
 
    
 
@@ -599,30 +599,3 @@ class BasePasterCommand(Command):
 
        path_to_ini_file = os.path.realpath(conf)
 
        conf = paste.deploy.appconfig('config:' + path_to_ini_file)
 
        pylonsconfig.init_app(conf.global_conf, conf.local_conf)
 

	
 

	
 

	
 
class UpgradeDb(BasePasterCommand):
 
    """Command used for paster to upgrade our database to newer version
 
    """
 

	
 
    max_args = 1
 
    min_args = 1
 

	
 
    usage = "CONFIG_FILE"
 
    summary = "Upgrades current db to newer version given configuration file"
 
    group_name = "RhodeCode"
 

	
 
    parser = Command.standard_parser(verbose=True)
 

	
 
    def command(self):
 
        from pylons import config
 
        raise NotImplementedError('Not implemented yet')
 

	
 

	
 
    def update_parser(self):
 
        self.parser.add_option('--sql',
 
                      action='store_true',
 
                      dest='just_sql',
 
                      help="Prints upgrade sql for further investigation",
 
                      default=False)
setup.py
Show inline comments
 
@@ -93,7 +93,7 @@ setup(
 

	
 
    [paste.global_paster_command]
 
    make-index = rhodecode.lib.indexers:MakeIndex
 
    upgrade-db = rhodecode.lib.utils:UpgradeDb
 
    upgrade-db = rhodecode.lib.dbmigrate:UpgradeDb
 
    celeryd=rhodecode.lib.celerypylons.commands:CeleryDaemonCommand
 
    celerybeat=rhodecode.lib.celerypylons.commands:CeleryBeatCommand
 
    camqadm=rhodecode.lib.celerypylons.commands:CAMQPAdminCommand
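
Once the package is re-installed so the entry points are regenerated, the renamed command resolves to the new module; a quick sanity check (the 'RhodeCode' distribution name is assumed from setup.py):

import pkg_resources

ep = pkg_resources.get_entry_info(
    'RhodeCode', 'paste.global_paster_command', 'upgrade-db')
print ep.module_name    # rhodecode.lib.dbmigrate
UpgradeDb = ep.load()   # command class now living in the dbmigrate package

After that, e.g. `paster upgrade-db your_config.ini --sql` is dispatched to the new UpgradeDb implementation.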
0 comments (0 inline, 0 general)