Changeset - 7f31de1584c6
rhodecode/lib/dbmigrate/migrate/__init__.py
Show inline comments
 
"""
 
   SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for
 
   database schema version and repository management and
 
   :mod:`migrate.changeset` that allows to define database schema changes
 
   using Python.
 
"""
 

	
 
from rhodecode.lib.dbmigrate.migrate.versioning import *
 
from rhodecode.lib.dbmigrate.migrate.changeset import *
 

	
 
__version__ = '0.7.2.dev'
 
\ No newline at end of file
rhodecode/lib/dbmigrate/migrate/changeset/__init__.py
Show inline comments
 
"""
 
   This module extends SQLAlchemy and provides additional DDL [#]_
 
   support.
 

	
 
   .. [#] SQL Data Definition Language
 
"""
 
import re
 
import warnings
 

	
 
import sqlalchemy
 
from sqlalchemy import __version__ as _sa_version
 

	
 
warnings.simplefilter('always', DeprecationWarning)
 

	
 
_sa_version = tuple(int(re.match("\d+", x).group(0))
 
                    for x in _sa_version.split("."))
 
SQLA_06 = _sa_version >= (0, 6)
 
SQLA_07 = _sa_version >= (0, 7)
 

	
 
del re
 
del _sa_version
 

	
 
from rhodecode.lib.dbmigrate.migrate.changeset.schema import *
 
from rhodecode.lib.dbmigrate.migrate.changeset.constraint import *
 

	
 
sqlalchemy.schema.Table.__bases__ += (ChangesetTable,)
 
sqlalchemy.schema.Column.__bases__ += (ChangesetColumn,)
 
sqlalchemy.schema.Index.__bases__ += (ChangesetIndex,)
 

	
 
sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause,)
rhodecode/lib/dbmigrate/migrate/changeset/schema.py
Show inline comments
 
"""
 
   Schema module providing common schema operations.
 
"""
 
import warnings
 

	
 
from UserDict import DictMixin
 

	
 
import sqlalchemy
 

	
 
from sqlalchemy.schema import ForeignKeyConstraint
 
from sqlalchemy.schema import UniqueConstraint
 

	
 
from rhodecode.lib.dbmigrate.migrate.exceptions import *
 
from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
 
from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06, SQLA_07
 
from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
 
                                                 run_single_visitor)
 

	
 

	
 
__all__ = [
 
    'create_column',
 
    'drop_column',
 
    'alter_column',
 
    'rename_table',
 
    'rename_index',
 
    'ChangesetTable',
 
    'ChangesetColumn',
 
    'ChangesetIndex',
 
    'ChangesetDefaultClause',
 
    'ColumnDelta',
 
]
 

	
 
def create_column(column, table=None, *p, **kw):
 
    """Create a column, given the table.
 

	
 
    API to :meth:`ChangesetColumn.create`.
 
    """
 
    if table is not None:
 
        return table.create_column(column, *p, **kw)
 
    return column.create(*p, **kw)
 

	
 

	
 
def drop_column(column, table=None, *p, **kw):
 
    """Drop a column, given the table.
 

	
 
    API to :meth:`ChangesetColumn.drop`.
 
    """
 
    if table is not None:
 
        return table.drop_column(column, *p, **kw)
 
    return column.drop(*p, **kw)
 

	
 

	
 
def rename_table(table, name, engine=None, **kw):
 
    """Rename a table.
 

	
 
    If Table instance is given, engine is not used.
 

	
 
    API to :meth:`ChangesetTable.rename`.
 

	
 
    :param table: Table to be renamed.
 
    :param name: New name for Table.
 
    :param engine: Engine instance.
 
    :type table: string or Table instance
 
    :type name: string
 
    :type engine: obj
 
    """
 
    table = _to_table(table, engine)
 
    table.rename(name, **kw)
 

	
 

	
 
def rename_index(index, name, table=None, engine=None, **kw):
 
    """Rename an index.
 

	
 
    If Index instance is given,
 
    table and engine are not used.
 

	
 
    API to :meth:`ChangesetIndex.rename`.
 

	
 
    :param index: Index to be renamed.
 
    :param name: New name for index.
 
    :param table: Table to which Index is reffered.
 
    :param engine: Engine instance.
 
    :type index: string or Index instance
 
    :type name: string
 
    :type table: string or Table instance
 
    :type engine: obj
 
    """
 
    index = _to_index(index, table, engine)
 
    index.rename(name, **kw)
 

	
 

	
 
def alter_column(*p, **k):
 
    """Alter a column.
 

	
 
    This is a helper function that creates a :class:`ColumnDelta` and
 
    runs it.
 

	
 
    :argument column:
 
      The name of the column to be altered or a
 
      :class:`ChangesetColumn` column representing it.
 

	
 
    :param table:
 
      A :class:`~sqlalchemy.schema.Table` or table name to
 
      for the table where the column will be changed.
 

	
 
    :param engine:
 
      The :class:`~sqlalchemy.engine.base.Engine` to use for table
 
      reflection and schema alterations.
 

	
 
    :returns: A :class:`ColumnDelta` instance representing the change.
 

	
 

	
 
    """
 

	
 
    if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
 
        k['table'] = p[0].table
 
    if 'engine' not in k:
 
        k['engine'] = k['table'].bind
 

	
 
    # deprecation
 
    if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
 
        warnings.warn(
 
            "Passing a Column object to alter_column is deprecated."
 
            " Just pass in keyword parameters instead.",
 
            MigrateDeprecationWarning
 
            )
 
    engine = k['engine']
 

	
 
    # enough tests seem to break when metadata is always altered
 
    # that this crutch has to be left in until they can be sorted
 
    # out
 
    k['alter_metadata']=True
 

	
 
    delta = ColumnDelta(*p, **k)
 

	
 
    visitorcallable = get_engine_visitor(engine, 'schemachanger')
 
    engine._run_visitor(visitorcallable, delta)
 

	
 
    return delta
 

	
 

	
 
def _to_table(table, engine=None):
 
    """Return if instance of Table, else construct new with metadata"""
 
    if isinstance(table, sqlalchemy.Table):
 
        return table
 

	
 
    # Given: table name, maybe an engine
 
    meta = sqlalchemy.MetaData()
 
    if engine is not None:
 
        meta.bind = engine
 
    return sqlalchemy.Table(table, meta)
 

	
 

	
 
def _to_index(index, table=None, engine=None):
 
    """Return if instance of Index, else construct new with metadata"""
 
    if isinstance(index, sqlalchemy.Index):
 
        return index
 

	
 
    # Given: index name; table name required
 
    table = _to_table(table, engine)
 
    ret = sqlalchemy.Index(index)
 
    ret.table = table
 
    return ret
 

	
 

	
 
class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
 
    """Extracts the differences between two columns/column-parameters
 

	
 
        May receive parameters arranged in several different ways:
 

	
 
        * **current_column, new_column, \*p, \*\*kw**
 
            Additional parameters can be specified to override column
 
            differences.
 

	
 
        * **current_column, \*p, \*\*kw**
 
            Additional parameters alter current_column. Table name is extracted
 
            from current_column object.
 
            Name is changed to current_column.name from current_name,
 
            if current_name is specified.
 

	
 
        * **current_col_name, \*p, \*\*kw**
 
            Table kw must specified.
 

	
 
        :param table: Table at which current Column should be bound to.\
 
        If table name is given, reflection will be used.
 
        :type table: string or Table instance
 

	
 
        :param metadata: A :class:`MetaData` instance to store
 
                         reflected table names
 

	
 
        :param engine: When reflecting tables, either engine or metadata must \
 
        be specified to acquire engine object.
 
        :type engine: :class:`Engine` instance
 
        :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
 
        `result_column` through :func:`dict` alike object.
 

	
 
        * :class:`ColumnDelta`.result_column is altered column with new attributes
 

	
 
        * :class:`ColumnDelta`.current_name is current name of column in db
 

	
 

	
 
    """
 

	
 
    # Column attributes that can be altered
 
    diff_keys = ('name', 'type', 'primary_key', 'nullable',
 
        'server_onupdate', 'server_default', 'autoincrement')
 
    diffs = dict()
 
@@ -366,286 +366,292 @@ class ColumnDelta(DictMixin, sqlalchemy.
 
                toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
 
                                            for_update=True))
 
        if toinit:
 
            column._init_items(*toinit)
 

	
 
        if not SQLA_06:
 
            column.args = []
 

	
 
    def _get_table(self):
 
        return getattr(self, '_table', None)
 

	
 
    def _set_table(self, table):
 
        if isinstance(table, basestring):
 
            if self.alter_metadata:
 
                if not self.meta:
 
                    raise ValueError("metadata must be specified for table"
 
                        " reflection when using alter_metadata")
 
                meta = self.meta
 
                if self.engine:
 
                    meta.bind = self.engine
 
            else:
 
                if not self.engine and not self.meta:
 
                    raise ValueError("engine or metadata must be specified"
 
                        " to reflect tables")
 
                if not self.engine:
 
                    self.engine = self.meta.bind
 
                meta = sqlalchemy.MetaData(bind=self.engine)
 
            self._table = sqlalchemy.Table(table, meta, autoload=True)
 
        elif isinstance(table, sqlalchemy.Table):
 
            self._table = table
 
            if not self.alter_metadata:
 
                self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
 
    def _get_result_column(self):
 
        return getattr(self, '_result_column', None)
 

	
 
    def _set_result_column(self, column):
 
        """Set Column to Table based on alter_metadata evaluation."""
 
        self.process_column(column)
 
        if not hasattr(self, 'current_name'):
 
            self.current_name = column.name
 
        if self.alter_metadata:
 
            self._result_column = column
 
        else:
 
            self._result_column = column.copy_fixed()
 

	
 
    table = property(_get_table, _set_table)
 
    result_column = property(_get_result_column, _set_result_column)
 

	
 

	
 
class ChangesetTable(object):
 
    """Changeset extensions to SQLAlchemy tables."""
 

	
 
    def create_column(self, column, *p, **kw):
 
        """Creates a column.
 

	
 
        The column parameter may be a column definition or the name of
 
        a column in this table.
 

	
 
        API to :meth:`ChangesetColumn.create`
 

	
 
        :param column: Column to be created
 
        :type column: Column instance or string
 
        """
 
        if not isinstance(column, sqlalchemy.Column):
 
            # It's a column name
 
            column = getattr(self.c, str(column))
 
        column.create(table=self, *p, **kw)
 

	
 
    def drop_column(self, column, *p, **kw):
 
        """Drop a column, given its name or definition.
 

	
 
        API to :meth:`ChangesetColumn.drop`
 

	
 
        :param column: Column to be droped
 
        :type column: Column instance or string
 
        """
 
        if not isinstance(column, sqlalchemy.Column):
 
            # It's a column name
 
            try:
 
                column = getattr(self.c, str(column))
 
            except AttributeError:
 
                # That column isn't part of the table. We don't need
 
                # its entire definition to drop the column, just its
 
                # name, so create a dummy column with the same name.
 
                column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
 
        column.drop(table=self, *p, **kw)
 

	
 
    def rename(self, name, connection=None, **kwargs):
 
        """Rename this table.
 

	
 
        :param name: New name of the table.
 
        :type name: string
 
        :param connection: reuse connection istead of creating new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        """
 
        engine = self.bind
 
        self.new_name = name
 
        visitorcallable = get_engine_visitor(engine, 'schemachanger')
 
        run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
 

	
 
        # Fix metadata registration
 
        self.name = name
 
        self.deregister()
 
        self._set_parent(self.metadata)
 

	
 
    def _meta_key(self):
 
        return sqlalchemy.schema._get_table_key(self.name, self.schema)
 

	
 
    def deregister(self):
 
        """Remove this table from its metadata"""
 
        key = self._meta_key()
 
        meta = self.metadata
 
        if key in meta.tables:
 
            del meta.tables[key]
 

	
 

	
 
class ChangesetColumn(object):
 
    """Changeset extensions to SQLAlchemy columns."""
 

	
 
    def alter(self, *p, **k):
 
        """Makes a call to :func:`alter_column` for the column this
 
        method is called on.
 
        """
 
        if 'table' not in k:
 
            k['table'] = self.table
 
        if 'engine' not in k:
 
            k['engine'] = k['table'].bind
 
        return alter_column(self, *p, **k)
 

	
 
    def create(self, table=None, index_name=None, unique_name=None,
 
               primary_key_name=None, populate_default=True, connection=None, **kwargs):
 
        """Create this column in the database.
 

	
 
        Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
 
        for most databases.
 

	
 
        :param table: Table instance to create on.
 
        :param index_name: Creates :class:`ChangesetIndex` on this column.
 
        :param unique_name: Creates :class:\
 
`~migrate.changeset.constraint.UniqueConstraint` on this column.
 
        :param primary_key_name: Creates :class:\
 
`~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
 
        :param populate_default: If True, created column will be \
 
populated with defaults
 
        :param connection: reuse connection istead of creating new one.
 
        :type table: Table instance
 
        :type index_name: string
 
        :type unique_name: string
 
        :type primary_key_name: string
 
        :type populate_default: bool
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 

	
 
        :returns: self
 
        """
 
        self.populate_default = populate_default
 
        self.index_name = index_name
 
        self.unique_name = unique_name
 
        self.primary_key_name = primary_key_name
 
        for cons in ('index_name', 'unique_name', 'primary_key_name'):
 
            self._check_sanity_constraints(cons)
 

	
 
        self.add_to_table(table)
 
        engine = self.table.bind
 
        visitorcallable = get_engine_visitor(engine, 'columngenerator')
 
        engine._run_visitor(visitorcallable, self, connection, **kwargs)
 

	
 
        # TODO: reuse existing connection
 
        if self.populate_default and self.default is not None:
 
            stmt = table.update().values({self: engine._execute_default(self.default)})
 
            engine.execute(stmt)
 

	
 
        return self
 

	
 
    def drop(self, table=None, connection=None, **kwargs):
 
        """Drop this column from the database, leaving its table intact.
 

	
 
        ``ALTER TABLE DROP COLUMN``, for most databases.
 

	
 
        :param connection: reuse connection istead of creating new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        """
 
        if table is not None:
 
            self.table = table
 
        engine = self.table.bind
 
        visitorcallable = get_engine_visitor(engine, 'columndropper')
 
        engine._run_visitor(visitorcallable, self, connection, **kwargs)
 
        self.remove_from_table(self.table, unset_table=False)
 
        self.table = None
 
        return self
 

	
 
    def add_to_table(self, table):
 
        if table is not None  and self.table is None:
 
            if SQLA_07:
 
                table.append_column(self)
 
            else:
 
            self._set_parent(table)
 

	
 
    def _col_name_in_constraint(self,cons,name):
 
        return False
 

	
 
    def remove_from_table(self, table, unset_table=True):
 
        # TODO: remove primary keys, constraints, etc
 
        if unset_table:
 
            self.table = None
 

	
 
        to_drop = set()
 
        for index in table.indexes:
 
            columns = []
 
            for col in index.columns:
 
                if col.name!=self.name:
 
                    columns.append(col)
 
            if columns:
 
                index.columns=columns
 
            else:
 
                to_drop.add(index)
 
        table.indexes = table.indexes - to_drop
 

	
 
        to_drop = set()
 
        for cons in table.constraints:
 
            # TODO: deal with other types of constraint
 
            if isinstance(cons,(ForeignKeyConstraint,
 
                                UniqueConstraint)):
 
                for col_name in cons.columns:
 
                    if not isinstance(col_name,basestring):
 
                        col_name = col_name.name
 
                    if self.name==col_name:
 
                        to_drop.add(cons)
 
        table.constraints = table.constraints - to_drop
 

	
 
        if table.c.contains_column(self):
 
            if SQLA_07:
 
                table._columns.remove(self)
 
            else:
 
            table.c.remove(self)
 

	
 
    # TODO: this is fixed in 0.6
 
    def copy_fixed(self, **kw):
 
        """Create a copy of this ``Column``, with all attributes."""
 
        return sqlalchemy.Column(self.name, self.type, self.default,
 
            key=self.key,
 
            primary_key=self.primary_key,
 
            nullable=self.nullable,
 
            quote=self.quote,
 
            index=self.index,
 
            unique=self.unique,
 
            onupdate=self.onupdate,
 
            autoincrement=self.autoincrement,
 
            server_default=self.server_default,
 
            server_onupdate=self.server_onupdate,
 
            *[c.copy(**kw) for c in self.constraints])
 

	
 
    def _check_sanity_constraints(self, name):
 
        """Check if constraints names are correct"""
 
        obj = getattr(self, name)
 
        if (getattr(self, name[:-5]) and not obj):
 
            raise InvalidConstraintError("Column.create() accepts index_name,"
 
            " primary_key_name and unique_name to generate constraints")
 
        if not isinstance(obj, basestring) and obj is not None:
 
            raise InvalidConstraintError(
 
            "%s argument for column must be constraint name" % name)
 

	
 

	
 
class ChangesetIndex(object):
 
    """Changeset extensions to SQLAlchemy Indexes."""
 

	
 
    __visit_name__ = 'index'
 

	
 
    def rename(self, name, connection=None, **kwargs):
 
        """Change the name of an index.
 

	
 
        :param name: New name of the Index.
 
        :type name: string
 
        :param connection: reuse connection istead of creating new one.
 
        :type connection: :class:`sqlalchemy.engine.base.Connection` instance
 
        """
 
        engine = self.table.bind
 
        self.new_name = name
 
        visitorcallable = get_engine_visitor(engine, 'schemachanger')
 
        engine._run_visitor(visitorcallable, self, connection, **kwargs)
 
        self.name = name
 

	
 

	
 
class ChangesetDefaultClause(object):
 
    """Implements comparison between :class:`DefaultClause` instances"""
 

	
 
    def __eq__(self, other):
 
        if isinstance(other, self.__class__):
 
            if self.arg == other.arg:
 
                return True
 

	
 
    def __ne__(self, other):
 
        return not self.__eq__(other)
rhodecode/lib/dbmigrate/migrate/exceptions.py
Show inline comments
 
"""
 
   Provide exception classes for :mod:`migrate`
 
"""
 

	
 

	
 
class Error(Exception):
 
    """Error base class."""
 

	
 

	
 
class ApiError(Error):
 
    """Base class for API errors."""
 

	
 

	
 
class KnownError(ApiError):
 
    """A known error condition."""
 

	
 

	
 
class UsageError(ApiError):
 
    """A known error condition where help should be displayed."""
 

	
 

	
 
class ControlledSchemaError(Error):
 
    """Base class for controlled schema errors."""
 

	
 

	
 
class InvalidVersionError(ControlledSchemaError):
 
    """Invalid version number."""
 

	
 

	
 
class DatabaseNotControlledError(ControlledSchemaError):
 
    """Database should be under version control, but it's not."""
 

	
 

	
 
class DatabaseAlreadyControlledError(ControlledSchemaError):
 
    """Database shouldn't be under version control, but it is"""
 

	
 

	
 
class WrongRepositoryError(ControlledSchemaError):
 
    """This database is under version control by another repository."""
 

	
 

	
 
class NoSuchTableError(ControlledSchemaError):
 
    """The table does not exist."""
 

	
 

	
 
class PathError(Error):
 
    """Base class for path errors."""
 

	
 

	
 
class PathNotFoundError(PathError):
 
    """A path with no file was required; found a file."""
 

	
 

	
 
class PathFoundError(PathError):
 
    """A path with a file was required; found no file."""
 

	
 

	
 
class RepositoryError(Error):
 
    """Base class for repository errors."""
 

	
 

	
 
class InvalidRepositoryError(RepositoryError):
 
    """Invalid repository error."""
 

	
 

	
 
class ScriptError(Error):
 
    """Base class for script errors."""
 

	
 

	
 
class InvalidScriptError(ScriptError):
 
    """Invalid script error."""
 

	
 

	
 
class InvalidVersionError(Error):
 
    """Invalid version error."""
 

	
 
# migrate.changeset
 

	
 
class NotSupportedError(Error):
 
    """Not supported error"""
 

	
 

	
 
class InvalidConstraintError(Error):
 
    """Invalid constraint error"""
 

	
 

	
 
class MigrateDeprecationWarning(DeprecationWarning):
 
    """Warning for deprecated features in Migrate"""
rhodecode/lib/dbmigrate/migrate/versioning/api.py
Show inline comments
 
"""
 
   This module provides an external API to the versioning system.
 

	
 
   .. versionchanged:: 0.6.0
 
    :func:`migrate.versioning.api.test` and schema diff functions
 
    changed order of positional arguments so all accept `url` and `repository`
 
    as first arguments.
 

	
 
   .. versionchanged:: 0.5.4
 
    ``--preview_sql`` displays source file when using SQL scripts.
 
    If Python script is used, it runs the action with mocked engine and
 
    returns captured SQL statements.
 

	
 
   .. versionchanged:: 0.5.4
 
    Deprecated ``--echo`` parameter in favour of new
 
    :func:`migrate.versioning.util.construct_engine` behavior.
 
"""
 

	
 
# Dear migrate developers,
 
#
 
# please do not comment this module using sphinx syntax because its
 
# docstrings are presented as user help and most users cannot
 
# interpret sphinx annotated ReStructuredText.
 
#
 
# Thanks,
 
# Jan Dittberner
 

	
 
import sys
 
import inspect
 
import logging
 

	
 
from rhodecode.lib.dbmigrate.migrate import exceptions
 
from rhodecode.lib.dbmigrate.migrate.versioning import repository, schema, version, \
 
    script as script_ # command name conflict
 
from rhodecode.lib.dbmigrate.migrate.versioning.util import catch_known_errors, with_engine
 

	
 

	
 
log = logging.getLogger(__name__)
 
command_desc = {
 
    'help': 'displays help on a given command',
 
    'create': 'create an empty repository at the specified path',
 
    'script': 'create an empty change Python script',
 
    'script_sql': 'create empty change SQL scripts for given database',
 
    'version': 'display the latest version available in a repository',
 
    'db_version': 'show the current version of the repository under version control',
 
    'source': 'display the Python code for a particular version in this repository',
 
    'version_control': 'mark a database as under this repository\'s version control',
 
    'upgrade': 'upgrade a database to a later version',
 
    'downgrade': 'downgrade a database to an earlier version',
 
    'drop_version_control': 'removes version control from a database',
 
    'manage': 'creates a Python script that runs Migrate with a set of default values',
 
    'test': 'performs the upgrade and downgrade command on the given database',
 
    'compare_model_to_db': 'compare MetaData against the current database state',
 
    'create_model': 'dump the current database as a Python model to stdout',
 
    'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
 
    'update_db_from_model': 'modify the database to match the structure of the current MetaData',
 
}
 
__all__ = command_desc.keys()
 

	
 
Repository = repository.Repository
 
ControlledSchema = schema.ControlledSchema
 
VerNum = version.VerNum
 
PythonScript = script_.PythonScript
 
SqlScript = script_.SqlScript
 

	
 

	
 
# deprecated
 
def help(cmd=None, **opts):
 
    """%prog help COMMAND
 

	
 
    Displays help on a given command.
 
    """
 
    if cmd is None:
 
        raise exceptions.UsageError(None)
 
    try:
 
        func = globals()[cmd]
 
    except:
 
        raise exceptions.UsageError(
 
            "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
 
    ret = func.__doc__
 
    if sys.argv[0]:
 
        ret = ret.replace('%prog', sys.argv[0])
 
    return ret
 

	
 
@catch_known_errors
 
def create(repository, name, **opts):
 
    """%prog create REPOSITORY_PATH NAME [--table=TABLE]
 

	
 
    Create an empty repository at the specified path.
 

	
 
    You can specify the version_table to be used; by default, it is
 
    'migrate_version'.  This table is created in all version-controlled
 
    databases.
 
    """
 
    repo_path = Repository.create(repository, name, **opts)
 

	
 

	
 
@catch_known_errors
 
def script(description, repository, **opts):
 
    """%prog script DESCRIPTION REPOSITORY_PATH
 

	
 
    Create an empty change script using the next unused version number
 
    appended with the given description.
 

	
 
    For instance, manage.py script "Add initial tables" creates:
 
    repository/versions/001_Add_initial_tables.py
 
    """
 
    repo = Repository(repository)
 
    repo.create_script(description, **opts)
 

	
 

	
 
@catch_known_errors
 
def script_sql(database, repository, **opts):
 
    """%prog script_sql DATABASE REPOSITORY_PATH
 
def script_sql(database, description, repository, **opts):
 
    """%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH
 

	
 
    Create empty change SQL scripts for given DATABASE, where DATABASE
 
    is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.)
 
    is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.)
 
    or generic ('default').
 

	
 
    For instance, manage.py script_sql postgres creates:
 
    repository/versions/001_postgres_upgrade.sql and
 
    repository/versions/001_postgres_postgres.sql
 
    For instance, manage.py script_sql postgresql description creates:
 
    repository/versions/001_description_postgresql_upgrade.sql and
 
    repository/versions/001_description_postgresql_postgres.sql
 
    """
 
    repo = Repository(repository)
 
    repo.create_script_sql(database, **opts)
 
    repo.create_script_sql(database, description, **opts)
 

	
 

	
 
def version(repository, **opts):
 
    """%prog version REPOSITORY_PATH
 

	
 
    Display the latest version available in a repository.
 
    """
 
    repo = Repository(repository)
 
    return repo.latest
 

	
 

	
 
@with_engine
 
def db_version(url, repository, **opts):
 
    """%prog db_version URL REPOSITORY_PATH
 

	
 
    Show the current version of the repository with the given
 
    connection string, under version control of the specified
 
    repository.
 

	
 
    The url should be any valid SQLAlchemy connection string.
 
    """
 
    engine = opts.pop('engine')
 
    schema = ControlledSchema(engine, repository)
 
    return schema.version
 

	
 

	
 
def source(version, dest=None, repository=None, **opts):
 
    """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
 

	
 
    Display the Python code for a particular version in this
 
    repository.  Save it to the file at DESTINATION or, if omitted,
 
    send to stdout.
 
    """
 
    if repository is None:
 
        raise exceptions.UsageError("A repository must be specified")
 
    repo = Repository(repository)
 
    ret = repo.version(version).script().source()
 
    if dest is not None:
 
        dest = open(dest, 'w')
 
        dest.write(ret)
 
        dest.close()
 
        ret = None
 
    return ret
 

	
 

	
 
def upgrade(url, repository, version=None, **opts):
 
    """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
 

	
 
    Upgrade a database to a later version.
 

	
 
    This runs the upgrade() function defined in your change scripts.
 

	
 
    By default, the database is updated to the latest available
 
    version. You may specify a version instead, if you wish.
 

	
 
    You may preview the Python or SQL code to be executed, rather than
 
    actually executing it, using the appropriate 'preview' option.
 
    """
 
    err = "Cannot upgrade a database of version %s to version %s. "\
 
        "Try 'downgrade' instead."
 
    return _migrate(url, repository, version, upgrade=True, err=err, **opts)
 

	
 

	
 
def downgrade(url, repository, version, **opts):
 
    """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
 

	
 
    Downgrade a database to an earlier version.
 

	
 
    This is the reverse of upgrade; this runs the downgrade() function
 
    defined in your change scripts.
 

	
 
    You may preview the Python or SQL code to be executed, rather than
 
    actually executing it, using the appropriate 'preview' option.
 
    """
 
    err = "Cannot downgrade a database of version %s to version %s. "\
 
        "Try 'upgrade' instead."
 
    return _migrate(url, repository, version, upgrade=False, err=err, **opts)
 

	
 
@with_engine
 
def test(url, repository, **opts):
 
    """%prog test URL REPOSITORY_PATH [VERSION]
 

	
 
    Performs the upgrade and downgrade option on the given
 
    database. This is not a real test and may leave the database in a
 
    bad state. You should therefore better run the test on a copy of
 
    your database.
 
    """
 
    engine = opts.pop('engine')
 
    repos = Repository(repository)
 
    script = repos.version(None).script()
 

	
 
    # Upgrade
 
    log.info("Upgrading...")
 
    script.run(engine, 1)
 
    log.info("done")
 

	
 
    log.info("Downgrading...")
 
    script.run(engine, -1)
 
    log.info("done")
 
    log.info("Success")
 

	
 

	
 
@with_engine
 
def version_control(url, repository, version=None, **opts):
 
    """%prog version_control URL REPOSITORY_PATH [VERSION]
 

	
 
    Mark a database as under this repository's version control.
 

	
 
    Once a database is under version control, schema changes should
 
    only be done via change scripts in this repository.
 

	
 
    This creates the table version_table in the database.
 

	
 
    The url should be any valid SQLAlchemy connection string.
 

	
 
    By default, the database begins at version 0 and is assumed to be
 
    empty.  If the database is not empty, you may specify a version at
 
    which to begin instead. No attempt is made to verify this
 
    version's correctness - the database schema is expected to be
 
    identical to what it would be if the database were created from
 
    scratch.
 
    """
 
    engine = opts.pop('engine')
 
    ControlledSchema.create(engine, repository, version)
 

	
 

	
 
@with_engine
 
def drop_version_control(url, repository, **opts):
 
    """%prog drop_version_control URL REPOSITORY_PATH
 

	
 
    Removes version control from a database.
 
    """
 
    engine = opts.pop('engine')
 
    schema = ControlledSchema(engine, repository)
 
    schema.drop()
 

	
 

	
 
def manage(file, **opts):
 
    """%prog manage FILENAME [VARIABLES...]
 

	
 
    Creates a script that runs Migrate with a set of default values.
 

	
 
    For example::
 

	
 
        %prog manage manage.py --repository=/path/to/repository \
 
--url=sqlite:///project.db
 

	
 
    would create the script manage.py. The following two commands
 
    would then have exactly the same results::
 

	
 
        python manage.py version
 
        %prog version --repository=/path/to/repository
 
    """
 
    Repository.create_manage_file(file, **opts)
 

	
 

	
 
@with_engine
 
def compare_model_to_db(url, repository, model, **opts):
 
    """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
 

	
 
    Compare the current model (assumed to be a module level variable
 
    of type sqlalchemy.MetaData) against the current database.
 

	
 
    NOTE: This is EXPERIMENTAL.
 
    """  # TODO: get rid of EXPERIMENTAL label
 
    engine = opts.pop('engine')
 
    return ControlledSchema.compare_model_to_db(engine, model, repository)
 

	
 

	
 
@with_engine
 
def create_model(url, repository, **opts):
 
    """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
 

	
 
    Dump the current database as a Python model to stdout.
 

	
 
    NOTE: This is EXPERIMENTAL.
 
    """  # TODO: get rid of EXPERIMENTAL label
 
    engine = opts.pop('engine')
 
    declarative = opts.get('declarative', False)
 
    return ControlledSchema.create_model(engine, repository, declarative)
 

	
 

	
 
@catch_known_errors
 
@with_engine
 
def make_update_script_for_model(url, repository, oldmodel, model, **opts):
 
    """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
 

	
 
    Create a script changing the old Python model to the new (current)
 
    Python model, sending to stdout.
 

	
 
    NOTE: This is EXPERIMENTAL.
 
    """  # TODO: get rid of EXPERIMENTAL label
rhodecode/lib/dbmigrate/migrate/versioning/genmodel.py
Show inline comments
 
"""
 
   Code to generate a Python model from a database or differences
 
   between a model and database.
 

	
 
   Some of this is borrowed heavily from the AutoCode project at:
 
   http://code.google.com/p/sqlautocode/
 
"""
 

	
 
import sys
 
import logging
 

	
 
import sqlalchemy
 

	
 
from rhodecode.lib.dbmigrate import migrate
 
from rhodecode.lib.dbmigrate.migrate import changeset
 

	
 

	
 
log = logging.getLogger(__name__)
 
HEADER = """
 
## File autogenerated by genmodel.py
 

	
 
from sqlalchemy import *
 
meta = MetaData()
 
"""
 

	
 
DECLARATIVE_HEADER = """
 
## File autogenerated by genmodel.py
 

	
 
from sqlalchemy import *
 
from sqlalchemy.ext import declarative
 

	
 
Base = declarative.declarative_base()
 
"""
 

	
 

	
 
class ModelGenerator(object):
 
    """Various transformations from an A, B diff.
 

	
 
    In the implementation, A tends to be called the model and B
 
    the database (although this is not true of all diffs).
 
    The diff is directionless, but transformations apply the diff
 
    in a particular direction, described in the method name.
 
    """
 

	
 
    def __init__(self, diff, engine, declarative=False):
 
        self.diff = diff
 
        self.engine = engine
 
        self.declarative = declarative
 

	
 
    def column_repr(self, col):
 
        kwarg = []
 
        if col.key != col.name:
 
            kwarg.append('key')
 
        if col.primary_key:
 
            col.primary_key = True  # otherwise it dumps it as 1
 
            kwarg.append('primary_key')
 
        if not col.nullable:
 
            kwarg.append('nullable')
 
        if col.onupdate:
 
            kwarg.append('onupdate')
 
        if col.default:
 
            if col.primary_key:
 
                # I found that PostgreSQL automatically creates a
 
                # default value for the sequence, but let's not show
 
                # that.
 
                pass
 
            else:
 
                kwarg.append('default')
 
        ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
 
        args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg]
 

	
 
        # crs: not sure if this is good idea, but it gets rid of extra
 
        # u''
 
        name = col.name.encode('utf8')
 

	
 
        type_ = col.type
 
        for cls in col.type.__class__.__mro__:
 
            if cls.__module__ == 'sqlalchemy.types' and \
 
                not cls.__name__.isupper():
 
                if cls is not type_.__class__:
 
                    type_ = cls()
 
                break
 

	
 
        type_repr = repr(type_)
 
        if type_repr.endswith('()'):
 
            type_repr = type_repr[:-2]
 

	
 
        constraints = [repr(cn) for cn in col.constraints]
 

	
 
        data = {
 
            'name': name,
 
            'type': type_,
 
            'constraints': ', '.join([repr(cn) for cn in col.constraints]),
 
            'args': ks and ks or ''}
 
            'commonStuff': ', '.join([type_repr] + constraints + args),
 
        }
 

	
 
        if data['constraints']:
 
            if data['args']:
 
                data['args'] = ',' + data['args']
 

	
 
        if data['constraints'] or data['args']:
 
            data['maybeComma'] = ','
 
        if self.declarative:
 
            return """%(name)s = Column(%(commonStuff)s)""" % data
 
        else:
 
            data['maybeComma'] = ''
 
            return """Column(%(name)r, %(commonStuff)s)""" % data
 

	
 
        commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
 
        commonStuff = commonStuff.strip()
 
        data['commonStuff'] = commonStuff
 
        if self.declarative:
 
            return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
 
        else:
 
            return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
 

	
 
    def getTableDefn(self, table):
 
    def _getTableDefn(self, table, metaName='meta'):
 
        out = []
 
        tableName = table.name
 
        if self.declarative:
 
            out.append("class %(table)s(Base):" % {'table': tableName})
 
            out.append("  __tablename__ = '%(table)s'" % {'table': tableName})
 
            out.append("    __tablename__ = '%(table)s'\n" %
 
                            {'table': tableName})
 
            for col in table.columns:
 
                out.append("  %s" % self.column_repr(col))
 
            out.append('\n')
 
        else:
 
            out.append("%(table)s = Table('%(table)s', meta," % \
 
                           {'table': tableName})
 
            out.append("%(table)s = Table('%(table)s', %(meta)s," %
 
                       {'table': tableName, 'meta': metaName})
 
            for col in table.columns:
 
                out.append("  %s," % self.column_repr(col))
 
            out.append(")")
 
            out.append(")\n")
 
        return out
 

	
 
    def _get_tables(self,missingA=False,missingB=False,modified=False):
 
        to_process = []
 
        for bool_,names,metadata in (
 
            (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
 
            (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
 
            (modified,self.diff.tables_different,self.diff.metadataA),
 
                ):
 
            if bool_:
 
                for name in names:
 
                    yield metadata.tables.get(name)
 

	
 
    def toPython(self):
 
        """Assume database is current and model is empty."""
 
    def genBDefinition(self):
 
        """Generates the source code for a definition of B.
 

	
 
        Assumes a diff where A is empty.
 

	
 
        Was: toPython. Assume database (B) is current and model (A) is empty.
 
        """
 

	
 
        out = []
 
        if self.declarative:
 
            out.append(DECLARATIVE_HEADER)
 
        else:
 
            out.append(HEADER)
 
        out.append("")
 
        for table in self._get_tables(missingA=True):
 
            out.extend(self.getTableDefn(table))
 
            out.append("")
 
            out.extend(self._getTableDefn(table))
 
        return '\n'.join(out)
 

	
 
    def toUpgradeDowngradePython(self, indent='    '):
 
        ''' Assume model is most current and database is out-of-date. '''
 
        decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
 
                 'meta = MetaData()']
 
        for table in self._get_tables(
 
            missingA=True,missingB=True,modified=True
 
            ):
 
            decls.extend(self.getTableDefn(table))
 
    def genB2AMigration(self, indent='    '):
 
        '''Generate a migration from B to A.
 

	
 
        Was: toUpgradeDowngradePython
 
        Assume model (A) is most current and database (B) is out-of-date.
 
        '''
 

	
 
        decls = ['from migrate.changeset import schema',
 
                 'pre_meta = MetaData()',
 
                 'post_meta = MetaData()',
 
                ]
 
        upgradeCommands = ['pre_meta.bind = migrate_engine',
 
                           'post_meta.bind = migrate_engine']
 
        downgradeCommands = list(upgradeCommands)
 

	
 
        for tn in self.diff.tables_missing_from_A:
 
            pre_table = self.diff.metadataB.tables[tn]
 
            decls.extend(self._getTableDefn(pre_table, metaName='pre_meta'))
 
            upgradeCommands.append(
 
                "pre_meta.tables[%(table)r].drop()" % {'table': tn})
 
            downgradeCommands.append(
 
                "pre_meta.tables[%(table)r].create()" % {'table': tn})
 

	
 
        upgradeCommands, downgradeCommands = [], []
 
        for tableName in self.diff.tables_missing_from_A:
 
            upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
 
            downgradeCommands.append("%(table)s.create()" % \
 
                                         {'table': tableName})
 
        for tableName in self.diff.tables_missing_from_B:
 
            upgradeCommands.append("%(table)s.create()" % {'table': tableName})
 
            downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
 
        for tn in self.diff.tables_missing_from_B:
 
            post_table = self.diff.metadataA.tables[tn]
 
            decls.extend(self._getTableDefn(post_table, metaName='post_meta'))
 
            upgradeCommands.append(
 
                "post_meta.tables[%(table)r].create()" % {'table': tn})
 
            downgradeCommands.append(
 
                "post_meta.tables[%(table)r].drop()" % {'table': tn})
 

	
 
        for tableName in self.diff.tables_different:
 
            dbTable = self.diff.metadataB.tables[tableName]
 
            missingInDatabase, missingInModel, diffDecl = \
 
                self.diff.colDiffs[tableName]
 
            for col in missingInDatabase:
 
                upgradeCommands.append('%s.columns[%r].create()' % (
 
                        modelTable, col.name))
 
                downgradeCommands.append('%s.columns[%r].drop()' % (
 
                        modelTable, col.name))
 
            for col in missingInModel:
 
                upgradeCommands.append('%s.columns[%r].drop()' % (
 
                        modelTable, col.name))
 
                downgradeCommands.append('%s.columns[%r].create()' % (
 
                        modelTable, col.name))
 
            for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
 
        for (tn, td) in self.diff.tables_different.iteritems():
 
            if td.columns_missing_from_A or td.columns_different:
 
                pre_table = self.diff.metadataB.tables[tn]
 
                decls.extend(self._getTableDefn(
 
                    pre_table, metaName='pre_meta'))
 
            if td.columns_missing_from_B or td.columns_different:
 
                post_table = self.diff.metadataA.tables[tn]
 
                decls.extend(self._getTableDefn(
 
                    post_table, metaName='post_meta'))
 

	
 
            for col in td.columns_missing_from_A:
 
                upgradeCommands.append(
 
                    'pre_meta.tables[%r].columns[%r].drop()' % (tn, col))
 
                downgradeCommands.append(
 
                    'pre_meta.tables[%r].columns[%r].create()' % (tn, col))
 
            for col in td.columns_missing_from_B:
 
                upgradeCommands.append(
 
                    'post_meta.tables[%r].columns[%r].create()' % (tn, col))
 
                downgradeCommands.append(
 
                    'post_meta.tables[%r].columns[%r].drop()' % (tn, col))
 
            for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different:
 
                upgradeCommands.append(
 
                    'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
 
                    modelTable, modelCol.name, databaseCol.name))
 
                    tn, modelCol.name, databaseCol.name))
 
                downgradeCommands.append(
 
                    'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
 
                    modelTable, modelCol.name, databaseCol.name))
 
        pre_command = '    meta.bind = migrate_engine'
 
                    tn, modelCol.name, databaseCol.name))
 

	
 
        return (
 
            '\n'.join(decls),
 
            '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
 
            '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
 
            '\n'.join('%s%s' % (indent, line) for line in upgradeCommands),
 
            '\n'.join('%s%s' % (indent, line) for line in downgradeCommands))
 

	
 
    def _db_can_handle_this_change(self,td):
 
        """Check if the database can handle going from B to A."""
 

	
 
        if (td.columns_missing_from_B
 
            and not td.columns_missing_from_A
 
            and not td.columns_different):
 
            # Even sqlite can handle this.
 
            # Even sqlite can handle column additions.
 
            return True
 
        else:
 
            return not self.engine.url.drivername.startswith('sqlite')
 

	
 
    def applyModel(self):
 
        """Apply model to current database."""
 
    def runB2A(self):
 
        """Goes from B to A.
 

	
 
        Was: applyModel. Apply model (A) to current database (B).
 
        """
 

	
 
        meta = sqlalchemy.MetaData(self.engine)
 

	
 
        for table in self._get_tables(missingA=True):
 
            table = table.tometadata(meta)
 
            table.drop()
 
        for table in self._get_tables(missingB=True):
 
            table = table.tometadata(meta)
 
            table.create()
 
        for modelTable in self._get_tables(modified=True):
 
            tableName = modelTable.name
 
            modelTable = modelTable.tometadata(meta)
 
            dbTable = self.diff.metadataB.tables[tableName]
 

	
 
            td = self.diff.tables_different[tableName]
 

	
 
            if self._db_can_handle_this_change(td):
 

	
 
                for col in td.columns_missing_from_B:
 
                    modelTable.columns[col].create()
 
                for col in td.columns_missing_from_A:
 
                    dbTable.columns[col].drop()
 
                # XXX handle column changes here.
 
            else:
 
                # Sqlite doesn't support drop column, so you have to
 
                # do more: create temp table, copy data to it, drop
 
                # old table, create new table, copy data back.
 
                #
 
                # I wonder if this is guaranteed to be unique?
 
                tempName = '_temp_%s' % modelTable.name
 

	
 
                def getCopyStatement():
 
                    preparer = self.engine.dialect.preparer
 
                    commonCols = []
 
                    for modelCol in modelTable.columns:
 
                        if modelCol.name in dbTable.columns:
 
                            commonCols.append(modelCol.name)
 
                    commonColsStr = ', '.join(commonCols)
 
                    return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
 
                        (tableName, commonColsStr, commonColsStr, tempName)
 

	
 
                # Move the data in one transaction, so that we don't
 
                # leave the database in a nasty state.
 
                connection = self.engine.connect()
 
                trans = connection.begin()
 
                try:
 
                    connection.execute(
 
                        'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
 
                            (tempName, modelTable.name))
 
                    # make sure the drop takes place inside our
 
                    # transaction with the bind parameter
 
                    modelTable.drop(bind=connection)
 
                    modelTable.create(bind=connection)
 
                    connection.execute(getCopyStatement())
 
                    connection.execute('DROP TABLE %s' % tempName)
 
                    trans.commit()
 
                except:
 
                    trans.rollback()
 
                    raise
 

	
rhodecode/lib/dbmigrate/migrate/versioning/repository.py
Show inline comments
 
"""
 
   SQLAlchemy migrate repository management.
 
"""
 
import os
 
import shutil
 
import string
 
import logging
 

	
 
from pkg_resources import resource_filename
 
from tempita import Template as TempitaTemplate
 

	
 
from rhodecode.lib.dbmigrate.migrate import exceptions
 
from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse
 
from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
 
from rhodecode.lib.dbmigrate.migrate.versioning.config import *
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class Changeset(dict):
 
    """A collection of changes to be applied to a database.
 

	
 
    Changesets are bound to a repository and manage a set of
 
    scripts from that repository.
 

	
 
    Behaves like a dict, for the most part. Keys are ordered based on step value.
 
    """
 

	
 
    def __init__(self, start, *changes, **k):
 
        """
 
        Give a start version; step must be explicitly stated.
 
        """
 
        self.step = k.pop('step', 1)
 
        self.start = version.VerNum(start)
 
        self.end = self.start
 
        for change in changes:
 
            self.add(change)
 

	
 
    def __iter__(self):
 
        return iter(self.items())
 

	
 
    def keys(self):
 
        """
 
        In a series of upgrades x -> y, keys are version x. Sorted.
 
        """
 
        ret = super(Changeset, self).keys()
 
        # Reverse order if downgrading
 
        ret.sort(reverse=(self.step < 1))
 
        return ret
 

	
 
    def values(self):
 
        return [self[k] for k in self.keys()]
 

	
 
    def items(self):
 
        return zip(self.keys(), self.values())
 

	
 
    def add(self, change):
 
        """Add new change to changeset"""
 
        key = self.end
 
        self.end += self.step
 
        self[key] = change
 

	
 
    def run(self, *p, **k):
 
        """Run the changeset scripts"""
 
        for version, script in self:
 
            script.run(*p, **k)
 

	
 

	
 
class Repository(pathed.Pathed):
 
    """A project's change script repository"""
 

	
 
    _config = 'migrate.cfg'
 
    _versions = 'versions'
 

	
 
    def __init__(self, path):
 
        log.debug('Loading repository %s...' % path)
 
        self.verify(path)
 
        super(Repository, self).__init__(path)
 
        self.config = cfgparse.Config(os.path.join(self.path, self._config))
 
        self.versions = version.Collection(os.path.join(self.path,
 
                                                      self._versions))
 
        log.debug('Repository %s loaded successfully' % path)
 
        log.debug('Config: %r' % self.config.to_dict())
 

	
 
    @classmethod
 
    def verify(cls, path):
 
        """
 
        Ensure the target path is a valid repository.
 

	
 
        :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
 
        """
 
        # Ensure the existence of required files
 
        try:
 
            cls.require_found(path)
 
            cls.require_found(os.path.join(path, cls._config))
 
            cls.require_found(os.path.join(path, cls._versions))
 
        except exceptions.PathNotFoundError, e:
 
            raise exceptions.InvalidRepositoryError(path)
 

	
 
    @classmethod
 
    def prepare_config(cls, tmpl_dir, name, options=None):
 
        """
 
        Prepare a project configuration file for a new project.
 

	
 
        :param tmpl_dir: Path to Repository template
 
        :param config_file: Name of the config file in Repository template
 
        :param name: Repository name
 
        :type tmpl_dir: string
 
        :type config_file: string
 
        :type name: string
 
        :returns: Populated config file
 
        """
 
        if options is None:
 
            options = {}
 
        options.setdefault('version_table', 'migrate_version')
 
        options.setdefault('repository_id', name)
 
        options.setdefault('required_dbs', [])
 
        options.setdefault('use_timestamp_numbering', '0')
 

	
 
        tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
 
        ret = TempitaTemplate(tmpl).substitute(options)
 

	
 
        # cleanup
 
        del options['__template_name__']
 

	
 
        return ret
 

	
 
    @classmethod
 
    def create(cls, path, name, **opts):
 
        """Create a repository at a specified path"""
 
        cls.require_notfound(path)
 
        theme = opts.pop('templates_theme', None)
 
        t_path = opts.pop('templates_path', None)
 

	
 
        # Create repository
 
        tmpl_dir = Template(t_path).get_repository(theme=theme)
 
        shutil.copytree(tmpl_dir, path)
 

	
 
        # Edit config defaults
 
        config_text = cls.prepare_config(tmpl_dir, name, options=opts)
 
        fd = open(os.path.join(path, cls._config), 'w')
 
        fd.write(config_text)
 
        fd.close()
 

	
 
        opts['repository_name'] = name
 

	
 
        # Create a management script
 
        manager = os.path.join(path, 'manage.py')
 
        Repository.create_manage_file(manager, templates_theme=theme,
 
            templates_path=t_path, **opts)
 

	
 
        return cls(path)
 

	
 
    def create_script(self, description, **k):
 
        """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
 
        
 
        k['use_timestamp_numbering'] = self.use_timestamp_numbering
 
        self.versions.create_new_python_version(description, **k)
 

	
 
    def create_script_sql(self, database, **k):
 
    def create_script_sql(self, database, description, **k):
 
        """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
 
        self.versions.create_new_sql_version(database, **k)
 
        k['use_timestamp_numbering'] = self.use_timestamp_numbering
 
        self.versions.create_new_sql_version(database, description, **k)
 

	
 
    @property
 
    def latest(self):
 
        """API to :attr:`migrate.versioning.version.Collection.latest`"""
 
        return self.versions.latest
 

	
 
    @property
 
    def version_table(self):
 
        """Returns version_table name specified in config"""
 
        return self.config.get('db_settings', 'version_table')
 

	
 
    @property
 
    def id(self):
 
        """Returns repository id specified in config"""
 
        return self.config.get('db_settings', 'repository_id')
 

	
 
    @property
 
    def use_timestamp_numbering(self):
 
        """Returns use_timestamp_numbering specified in config"""
 
        ts_numbering = self.config.get('db_settings', 'use_timestamp_numbering', raw=True)
 
        
 
        return ts_numbering
 

	
 
    def version(self, *p, **k):
 
        """API to :attr:`migrate.versioning.version.Collection.version`"""
 
        return self.versions.version(*p, **k)
 

	
 
    @classmethod
 
    def clear(cls):
 
        # TODO: deletes repo
 
        super(Repository, cls).clear()
 
        version.Collection.clear()
 

	
 
    def changeset(self, database, start, end=None):
 
        """Create a changeset to migrate this database from ver. start to end/latest.
 

	
 
        :param database: name of database to generate changeset
 
        :param start: version to start at
 
        :param end: version to end at (latest if None given)
 
        :type database: string
 
        :type start: int
 
        :type end: int
 
        :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
 
        """
 
        start = version.VerNum(start)
 

	
 
        if end is None:
 
            end = self.latest
 
        else:
 
            end = version.VerNum(end)
 

	
 
        if start <= end:
 
            step = 1
 
            range_mod = 1
 
            op = 'upgrade'
 
        else:
 
            step = -1
 
            range_mod = 0
 
            op = 'downgrade'
 

	
 
        versions = range(start + range_mod, end + range_mod, step)
 
        changes = [self.version(v).script(database, op) for v in versions]
 
        ret = Changeset(start, step=step, *changes)
 
        return ret
 

	
 
    @classmethod
 
    def create_manage_file(cls, file_, **opts):
 
        """Create a project management script (manage.py)
 

	
 
        :param file_: Destination file to be written
 
        :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
 
        """
 
        mng_file = Template(opts.pop('templates_path', None))\
 
            .get_manage(theme=opts.pop('templates_theme', None))
 

	
 
        tmpl = open(mng_file).read()
 
        fd = open(file_, 'w')
 
        fd.write(TempitaTemplate(tmpl).substitute(opts))
 
        fd.close()
rhodecode/lib/dbmigrate/migrate/versioning/schema.py
Show inline comments
 
"""
 
   Database schema version management.
 
"""
 
import sys
 
import logging
 

	
 
from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
 
    create_engine)
 
from sqlalchemy.sql import and_
 
from sqlalchemy import exceptions as sa_exceptions
 
from sqlalchemy.sql import bindparam
 

	
 
from rhodecode.lib.dbmigrate.migrate import exceptions
 
from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07
 
from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
 
from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
 
from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model
 
from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class ControlledSchema(object):
 
    """A database under version control"""
 

	
 
    def __init__(self, engine, repository):
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 
        self.engine = engine
 
        self.repository = repository
 
        self.meta = MetaData(engine)
 
        self.load()
 

	
 
    def __eq__(self, other):
 
        """Compare two schemas by repositories and versions"""
 
        return (self.repository is other.repository \
 
            and self.version == other.version)
 

	
 
    def load(self):
 
        """Load controlled schema version info from DB"""
 
        tname = self.repository.version_table
 
        try:
 
            if not hasattr(self, 'table') or self.table is None:
 
                    self.table = Table(tname, self.meta, autoload=True)
 

	
 
            result = self.engine.execute(self.table.select(
 
                self.table.c.repository_id == str(self.repository.id)))
 

	
 
            data = list(result)[0]
 
        except:
 
            cls, exc, tb = sys.exc_info()
 
            raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
 

	
 
        self.version = data['version']
 
        return data
 

	
 
    def drop(self):
 
        """
 
        Remove version control from a database.
 
        """
 
        if SQLA_07:
 
            try:
 
                self.table.drop()
 
            except sa_exceptions.DatabaseError:
 
                raise exceptions.DatabaseNotControlledError(str(self.table))
 
        else:
 
        try:
 
            self.table.drop()
 
        except (sa_exceptions.SQLError):
 
            raise exceptions.DatabaseNotControlledError(str(self.table))
 

	
 
    def changeset(self, version=None):
 
        """API to Changeset creation.
 

	
 
        Uses self.version for start version and engine.name
 
        to get database name.
 
        """
 
        database = self.engine.name
 
        start_ver = self.version
 
        changeset = self.repository.changeset(database, start_ver, version)
 
        return changeset
 

	
 
    def runchange(self, ver, change, step):
 
        startver = ver
 
        endver = ver + step
 
        # Current database version must be correct! Don't run if corrupt!
 
        if self.version != startver:
 
            raise exceptions.InvalidVersionError("%s is not %s" % \
 
                                                     (self.version, startver))
 
        # Run the change
 
        change.run(self.engine, step)
 

	
 
        # Update/refresh database version
 
        self.update_repository_table(startver, endver)
 
        self.load()
 

	
 
    def update_repository_table(self, startver, endver):
 
        """Update version_table with new information"""
 
        update = self.table.update(and_(self.table.c.version == int(startver),
 
             self.table.c.repository_id == str(self.repository.id)))
 
        self.engine.execute(update, version=int(endver))
 

	
 
    def upgrade(self, version=None):
 
        """
 
        Upgrade (or downgrade) to a specified version, or latest version.
 
        """
 
        changeset = self.changeset(version)
 
        for ver, change in changeset:
 
            self.runchange(ver, change, changeset.step)
 

	
 
    def update_db_from_model(self, model):
 
        """
 
        Modify the database to match the structure of the current Python model.
 
        """
 
        model = load_model(model)
 

	
 
        diff = schemadiff.getDiffOfModelAgainstDatabase(
 
            model, self.engine, excludeTables=[self.repository.version_table]
 
            )
 
        genmodel.ModelGenerator(diff,self.engine).applyModel()
 
        genmodel.ModelGenerator(diff,self.engine).runB2A()
 

	
 
        self.update_repository_table(self.version, int(self.repository.latest))
 

	
 
        self.load()
 

	
 
    @classmethod
 
    def create(cls, engine, repository, version=None):
 
        """
 
        Declare a database to be under a repository's version control.
 

	
 
        :raises: :exc:`DatabaseAlreadyControlledError`
 
        :returns: :class:`ControlledSchema`
 
        """
 
        # Confirm that the version # is valid: positive, integer,
 
        # exists in repos
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 
        version = cls._validate_version(repository, version)
 
        table = cls._create_table_version(engine, repository, version)
 
        # TODO: history table
 
        # Load repository information and return
 
        return cls(engine, repository)
 

	
 
    @classmethod
 
    def _validate_version(cls, repository, version):
 
        """
 
        Ensures this is a valid version number for this repository.
 

	
 
        :raises: :exc:`InvalidVersionError` if invalid
 
        :return: valid version number
 
        """
 
        if version is None:
 
            version = 0
 
        try:
 
            version = VerNum(version) # raises valueerror
 
            if version < 0 or version > repository.latest:
 
                raise ValueError()
 
        except ValueError:
 
            raise exceptions.InvalidVersionError(version)
 
        return version
 

	
 
    @classmethod
 
    def _create_table_version(cls, engine, repository, version):
 
        """
 
        Creates the versioning table in a database.
 

	
 
        :raises: :exc:`DatabaseAlreadyControlledError`
 
        """
 
        # Create tables
 
        tname = repository.version_table
 
        meta = MetaData(engine)
 

	
 
        table = Table(
 
            tname, meta,
 
            Column('repository_id', String(250), primary_key=True),
 
            Column('repository_path', Text),
 
            Column('version', Integer), )
 

	
 
        # there can be multiple repositories/schemas in the same db
 
        if not table.exists():
 
            table.create()
 

	
 
        # test for existing repository_id
 
        s = table.select(table.c.repository_id == bindparam("repository_id"))
 
        result = engine.execute(s, repository_id=repository.id)
 
        if result.fetchone():
 
            raise exceptions.DatabaseAlreadyControlledError
 

	
 
        # Insert data
 
        engine.execute(table.insert().values(
 
                           repository_id=repository.id,
 
                           repository_path=repository.path,
 
                           version=int(version)))
 
        return table
 

	
 
    @classmethod
 
    def compare_model_to_db(cls, engine, model, repository):
 
        """
 
        Compare the current model against the current database.
 
        """
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 
        model = load_model(model)
 

	
 
        diff = schemadiff.getDiffOfModelAgainstDatabase(
 
            model, engine, excludeTables=[repository.version_table])
 
        return diff
 

	
 
    @classmethod
 
    def create_model(cls, engine, repository, declarative=False):
 
        """
 
        Dump the current database as a Python model.
 
        """
 
        if isinstance(repository, basestring):
 
            repository = Repository(repository)
 

	
 
        diff = schemadiff.getDiffOfModelAgainstDatabase(
 
            MetaData(), engine, excludeTables=[repository.version_table]
 
            )
 
        return genmodel.ModelGenerator(diff, engine, declarative).toPython()
 
        return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition()
rhodecode/lib/dbmigrate/migrate/versioning/script/py.py
Show inline comments
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
import shutil
 
import warnings
 
import logging
 
import inspect
 
from StringIO import StringIO
 

	
 
from rhodecode.lib.dbmigrate import migrate
 
from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
 
from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
 
from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
 
from rhodecode.lib.dbmigrate.migrate.versioning.script import base
 
from rhodecode.lib.dbmigrate.migrate.versioning.util import import_path, load_model, with_engine
 
from rhodecode.lib.dbmigrate.migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
 

	
 
log = logging.getLogger(__name__)
 
__all__ = ['PythonScript']
 

	
 

	
 
class PythonScript(base.BaseScript):
 
    """Base for Python scripts"""
 

	
 
    @classmethod
 
    def create(cls, path, **opts):
 
        """Create an empty migration script at specified path
 

	
 
        :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
 
        cls.require_notfound(path)
 

	
 
        src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
 
        shutil.copy(src, path)
 

	
 
        return cls(path)
 

	
 
    @classmethod
 
    def make_update_script_for_model(cls, engine, oldmodel,
 
                                     model, repository, **opts):
 
        """Create a migration script based on difference between two SA models.
 

	
 
        :param repository: path to migrate repository
 
        :param oldmodel: dotted.module.name:SAClass or SAClass object
 
        :param model: dotted.module.name:SAClass or SAClass object
 
        :param engine: SQLAlchemy engine
 
        :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
 
        :type oldmodel: string or Class
 
        :type model: string or Class
 
        :type engine: Engine instance
 
        :returns: Upgrade / Downgrade script
 
        :rtype: string
 
        """
 

	
 
        if isinstance(repository, basestring):
 
            # oh dear, an import cycle!
 
            from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
 
            repository = Repository(repository)
 

	
 
        oldmodel = load_model(oldmodel)
 
        model = load_model(model)
 

	
 
        # Compute differences.
 
        diff = schemadiff.getDiffOfModelAgainstModel(
 
            model,
 
            oldmodel,
 
            model,
 
            excludeTables=[repository.version_table])
 
        # TODO: diff can be False (there is no difference?)
 
        decls, upgradeCommands, downgradeCommands = \
 
            genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython()
 
            genmodel.ModelGenerator(diff,engine).genB2AMigration()
 

	
 
        # Store differences into file.
 
        src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
 
        f = open(src)
 
        contents = f.read()
 
        f.close()
 

	
 
        # generate source
 
        search = 'def upgrade(migrate_engine):'
 
        contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
 
        if upgradeCommands:
 
            contents = contents.replace('    pass', upgradeCommands, 1)
 
        if downgradeCommands:
 
            contents = contents.replace('    pass', downgradeCommands, 1)
 
        return contents
 

	
 
    @classmethod
 
    def verify_module(cls, path):
 
        """Ensure path is a valid script
 

	
 
        :param path: Script location
 
        :type path: string
 
        :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
 
        :returns: Python module
 
        """
 
        # Try to import and get the upgrade() func
 
        module = import_path(path)
 
        try:
 
            assert callable(module.upgrade)
 
        except Exception, e:
 
            raise InvalidScriptError(path + ': %s' % str(e))
 
        return module
 

	
 
    def preview_sql(self, url, step, **args):
 
        """Mocks SQLAlchemy Engine to store all executed calls in a string
 
        and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
 

	
 
        :returns: SQL file
 
        """
 
        buf = StringIO()
 
        args['engine_arg_strategy'] = 'mock'
 
        args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
 

	
 
        @with_engine
 
        def go(url, step, **kw):
 
            engine = kw.pop('engine')
 
            self.run(engine, step)
 
            return buf.getvalue()
 

	
 
        return go(url, step, **args)
 

	
 
    def run(self, engine, step):
 
        """Core method of Script file.
 
        Exectues :func:`update` or :func:`downgrade` functions
 

	
 
        :param engine: SQLAlchemy Engine
 
        :param step: Operation to run
 
        :type engine: string
 
        :type step: int
 
        """
 
        if step > 0:
 
            op = 'upgrade'
 
        elif step < 0:
 
            op = 'downgrade'
 
        else:
 
            raise ScriptError("%d is not a valid step" % step)
 

	
 
        funcname = base.operations[op]
 
        script_func = self._func(funcname)
 

	
 
        # check for old way of using engine
 
        if not inspect.getargspec(script_func)[0]:
 
            raise TypeError("upgrade/downgrade functions must accept engine"
 
                " parameter (since version 0.5.4)")
 

	
 
        script_func(engine)
 

	
 
    @property
 
    def module(self):
 
        """Calls :meth:`migrate.versioning.script.py.verify_module`
 
        and returns it.
 
        """
 
        if not hasattr(self, '_module'):
 
            self._module = self.verify_module(self.path)
 
        return self._module
 

	
 
    def _func(self, funcname):
 
        if not hasattr(self.module, funcname):
 
            msg = "Function '%s' is not defined in this script"
 
            raise ScriptError(msg % funcname)
 
        return getattr(self.module, funcname)
rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/migrate.cfg
Show inline comments
 
[db_settings]
 
# Used to identify which repository this database is versioned under.
 
# You can use the name of your project.
 
repository_id={{ locals().pop('repository_id') }}
 

	
 
# The name of the database table used to track the schema version.
 
# This name shouldn't already be used by your project.
 
# If this is changed once a database is under version control, you'll need to 
 
# change the table name in each database too. 
 
version_table={{ locals().pop('version_table') }}
 

	
 
# When committing a change script, Migrate will attempt to generate the 
 
# sql for all supported databases; normally, if one of them fails - probably
 
# because you don't have that database installed - it is ignored and the 
 
# commit continues, perhaps ending successfully. 
 
# Databases in this list MUST compile successfully during a commit, or the 
 
# entire commit will fail. List the databases your application will actually 
 
# be using to ensure your updates to that database work properly.
 
# This must be a list; example: ['postgres','sqlite']
 
required_dbs={{ locals().pop('required_dbs') }}
 

	
 
# When creating new change scripts, Migrate will stamp the new script with
 
# a version number. By default this is latest_version + 1. You can set this
 
# to 'true' to tell Migrate to use the UTC timestamp instead.
 
use_timestamp_numbering='false'
 
\ No newline at end of file
rhodecode/lib/dbmigrate/migrate/versioning/version.py
Show inline comments
 
#!/usr/bin/env python
 
# -*- coding: utf-8 -*-
 

	
 
import os
 
import re
 
import shutil
 
import logging
 

	
 
from rhodecode.lib.dbmigrate.migrate import exceptions
 
from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script
 
from datetime import datetime
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 
class VerNum(object):
 
    """A version number that behaves like a string and int at the same time"""
 

	
 
    _instances = dict()
 

	
 
    def __new__(cls, value):
 
        val = str(value)
 
        if val not in cls._instances:
 
            cls._instances[val] = super(VerNum, cls).__new__(cls)
 
        ret = cls._instances[val]
 
        return ret
 

	
 
    def __init__(self,value):
 
        self.value = str(int(value))
 
        if self < 0:
 
            raise ValueError("Version number cannot be negative")
 

	
 
    def __add__(self, value):
 
        ret = int(self) + int(value)
 
        return VerNum(ret)
 

	
 
    def __sub__(self, value):
 
        return self + (int(value) * -1)
 

	
 
    def __cmp__(self, value):
 
        return int(self) - int(value)
 

	
 
    def __repr__(self):
 
        return "<VerNum(%s)>" % self.value
 

	
 
    def __str__(self):
 
        return str(self.value)
 

	
 
    def __int__(self):
 
        return int(self.value)
 

	
 

	
 
class Collection(pathed.Pathed):
 
    """A collection of versioning scripts in a repository"""
 

	
 
    FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
 

	
 
    def __init__(self, path):
 
        """Collect current version scripts in repository
 
        and store them in self.versions
 
        """
 
        super(Collection, self).__init__(path)
 

	
 
        # Create temporary list of files, allowing skipped version numbers.
 
        files = os.listdir(path)
 
        if '1' in files:
 
            # deprecation
 
            raise Exception('It looks like you have a repository in the old '
 
                'format (with directories for each version). '
 
                'Please convert repository before proceeding.')
 

	
 
        tempVersions = dict()
 
        for filename in files:
 
            match = self.FILENAME_WITH_VERSION.match(filename)
 
            if match:
 
                num = int(match.group(1))
 
                tempVersions.setdefault(num, []).append(filename)
 
            else:
 
                pass  # Must be a helper file or something, let's ignore it.
 

	
 
        # Create the versions member where the keys
 
        # are VerNum's and the values are Version's.
 
        self.versions = dict()
 
        for num, files in tempVersions.items():
 
            self.versions[VerNum(num)] = Version(num, path, files)
 

	
 
    @property
 
    def latest(self):
 
        """:returns: Latest version in Collection"""
 
        return max([VerNum(0)] + self.versions.keys())
 

	
 
    def _next_ver_num(self, use_timestamp_numbering):
 
        print use_timestamp_numbering
 
        if use_timestamp_numbering == True:
 
            print "Creating new timestamp version!"
 
            return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S')))
 
        else:
 
            return self.latest + 1
 

	
 
    def create_new_python_version(self, description, **k):
 
        """Create Python files for new version"""
 
        ver = self.latest + 1
 
        ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
 
        extra = str_to_filename(description)
 

	
 
        if extra:
 
            if extra == '_':
 
                extra = ''
 
            elif not extra.startswith('_'):
 
                extra = '_%s' % extra
 

	
 
        filename = '%03d%s.py' % (ver, extra)
 
        filepath = self._version_path(filename)
 

	
 
        script.PythonScript.create(filepath, **k)
 
        self.versions[ver] = Version(ver, self.path, [filename])
 

	
 
    def create_new_sql_version(self, database, **k):
 
    def create_new_sql_version(self, database, description, **k):
 
        """Create SQL files for new version"""
 
        ver = self.latest + 1
 
        ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
 
        self.versions[ver] = Version(ver, self.path, [])
 

	
 
        extra = str_to_filename(description)
 

	
 
        if extra:
 
            if extra == '_':
 
                extra = ''
 
            elif not extra.startswith('_'):
 
                extra = '_%s' % extra
 

	
 
        # Create new files.
 
        for op in ('upgrade', 'downgrade'):
 
            filename = '%03d_%s_%s.sql' % (ver, database, op)
 
            filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
 
            filepath = self._version_path(filename)
 
            script.SqlScript.create(filepath, **k)
 
            self.versions[ver].add_script(filepath)
 

	
 
    def version(self, vernum=None):
 
        """Returns latest Version if vernum is not given.
 
        Otherwise, returns wanted version"""
 
        if vernum is None:
 
            vernum = self.latest
 
        return self.versions[VerNum(vernum)]
 

	
 
    @classmethod
 
    def clear(cls):
 
        super(Collection, cls).clear()
 

	
 
    def _version_path(self, ver):
 
        """Returns path of file in versions repository"""
 
        return os.path.join(self.path, str(ver))
 

	
 

	
 
class Version(object):
 
    """A single version in a collection
 
    :param vernum: Version Number
 
    :param path: Path to script files
 
    :param filelist: List of scripts
 
    :type vernum: int, VerNum
 
    :type path: string
 
    :type filelist: list
 
    """
 

	
 
    def __init__(self, vernum, path, filelist):
 
        self.version = VerNum(vernum)
 

	
 
        # Collect scripts in this folder
 
        self.sql = dict()
 
        self.python = None
 

	
 
        for script in filelist:
 
            self.add_script(os.path.join(path, script))
 

	
 
    def script(self, database=None, operation=None):
 
        """Returns SQL or Python Script"""
 
        for db in (database, 'default'):
 
            # Try to return a .sql script first
 
            try:
 
                return self.sql[db][operation]
 
            except KeyError:
 
                continue  # No .sql script exists
 

	
 
        # TODO: maybe add force Python parameter?
 
        ret = self.python
 

	
 
        assert ret is not None, \
 
            "There is no script for %d version" % self.version
 
        return ret
 

	
 
    def add_script(self, path):
 
        """Add script to Collection/Version"""
 
        if path.endswith(Extensions.py):
 
            self._add_script_py(path)
 
        elif path.endswith(Extensions.sql):
 
            self._add_script_sql(path)
 

	
 
    SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
 
    SQL_FILENAME = re.compile(r'^.*\.sql')
 

	
 
    def _add_script_sql(self, path):
 
        basename = os.path.basename(path)
 
        match = self.SQL_FILENAME.match(basename)
 

	
 
        if match:
 
            version, dbms, op = match.group(1), match.group(2), match.group(3)
 
            basename = basename.replace('.sql', '')
 
            parts = basename.split('_')
 
            if len(parts) < 3:
 
                raise exceptions.ScriptError(
 
                    "Invalid SQL script name %s " % basename + \
 
                    "(needs to be ###_description_database_operation.sql)")
 
            version = parts[0]
 
            op = parts[-1]
 
            dbms = parts[-2]
 
        else:
 
            raise exceptions.ScriptError(
 
                "Invalid SQL script name %s " % basename + \
 
                "(needs to be ###_database_operation.sql)")
 
                "(needs to be ###_description_database_operation.sql)")
 

	
 
        # File the script into a dictionary
 
        self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
 

	
 
    def _add_script_py(self, path):
 
        if self.python is not None:
 
            raise exceptions.ScriptError('You can only have one Python script '
 
                'per version, but you have: %s and %s' % (self.python, path))
 
        self.python = script.PythonScript(path)
 

	
 

	
 
class Extensions:
 
    """A namespace for file extensions"""
 
    py = 'py'
 
    sql = 'sql'
 

	
 
def str_to_filename(s):
 
    """Replaces spaces, (double and single) quotes
 
    and double underscores to underscores
 
    """
 

	
 
    s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
 
    while '__' in s:
 
        s = s.replace('__', '_')
 
    return s
rhodecode/lib/dbmigrate/versions/003_version_1_2_0.py
Show inline comments
 
import logging
 
import datetime
 

	
 
from sqlalchemy import *
 
from sqlalchemy.exc import DatabaseError
 
from sqlalchemy.orm import relation, backref, class_mapper
 
from sqlalchemy.orm.session import Session
 

	
 
from rhodecode.lib.dbmigrate.migrate import *
 
from rhodecode.lib.dbmigrate.migrate.changeset import *
 

	
 
from rhodecode.model.meta import Base
 

	
 
log = logging.getLogger(__name__)
 

	
 
def upgrade(migrate_engine):
 
    """ Upgrade operations go here.
 
    Don't create your own engine; bind migrate_engine to your metadata
 
    """
 

	
 
    #==========================================================================
 
    # Add table `groups``
 
    #==========================================================================
 
    from rhodecode.model.db import Group
 
    Group().__table__.create()
 

	
 
    #==========================================================================
 
    # Add table `group_to_perm`
 
    #==========================================================================
 
    from rhodecode.model.db import GroupToPerm
 
    GroupToPerm().__table__.create()
 

	
 
    #==========================================================================
 
    # Add table `users_groups`
 
    #==========================================================================
 
    from rhodecode.model.db import UsersGroup
 
    UsersGroup().__table__.create()
 

	
 
    #==========================================================================
 
    # Add table `users_groups_members`
 
    #==========================================================================
 
    from rhodecode.model.db import UsersGroupMember
 
    UsersGroupMember().__table__.create()
 

	
 
    #==========================================================================
 
    # Add table `users_group_repo_to_perm`
 
    #==========================================================================
 
    from rhodecode.model.db import UsersGroupRepoToPerm
 
    UsersGroupRepoToPerm().__table__.create()
 

	
 
    #==========================================================================
 
    # Add table `users_group_to_perm`
 
    #==========================================================================
 
    from rhodecode.model.db import UsersGroupToPerm
 
    UsersGroupToPerm().__table__.create()
 

	
 
    #==========================================================================
 
    # Upgrade of `users` table
 
    #==========================================================================
 
    from rhodecode.model.db import User
 

	
 
    #add column
 
    ldap_dn = Column("ldap_dn", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    ldap_dn.create(User().__table__)
 

	
 
    api_key = Column("api_key", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
 
    api_key.create(User().__table__)
 

	
 
    #remove old column
 
    is_ldap = Column("is_ldap", Boolean(), nullable=False, unique=None, default=False)
 
    is_ldap.drop(User().__table__)
 

	
 

	
 
    #==========================================================================
 
    # Upgrade of `repositories` table
 
    #==========================================================================
 
    from rhodecode.model.db import Repository
 

	
 
    #ADD downloads column#
 
    enable_downloads = Column("downloads", Boolean(), nullable=True, unique=None, default=True)
 
    enable_downloads.create(Repository().__table__)
 

	
 
    #ADD column created_on
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=True,
 
                        unique=None, default=datetime.datetime.now)
 
    created_on.create(Repository().__table__)
 

	
 
    #ADD group_id column#
 
    group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'),
 
                  nullable=True, unique=False, default=None)
 

	
 
    group_id.create(Repository().__table__)
 

	
 

	
 
    #ADD clone_uri column#
 

	
 
    clone_uri = Column("clone_uri", String(length=255, convert_unicode=False,
 
                                           assert_unicode=None),
 
                        nullable=True, unique=False, default=None)
 

	
 
    clone_uri.create(Repository().__table__)
 

	
 

	
 
    #==========================================================================
 
    # Upgrade of `user_followings` table
 
    #==========================================================================
 

	
 
    follows_from = Column('follows_from', DateTime(timezone=False), nullable=True, unique=None, default=datetime.datetime.now)
 
    follows_from.create(Repository().__table__)
 

	
 
    return
 

	
 

	
 
def downgrade(migrate_engine):
 
    meta = MetaData()
 
    meta.bind = migrate_engine
0 comments (0 inline, 0 general)