Changeset - cc1ab5ef6686
[Not reviewed]
default
0 25 0
Mads Kiilerich - 11 years ago 2015-01-06 00:54:36
madski@unity3d.com
cleanup: avoid some 'except Exception' catching - catch specific exceptions or log it and show what happened

This has a risk of introducing regressions ... but we want to get rid of all
exception muting and make the whole system less fragile and easier to debug.
25 files changed with 287 insertions and 416 deletions:
0 comments (0 inline, 0 general)
kallithea/bin/kallithea_api.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.bin.api
 
~~~~~~~~~~~~~~~~~
 

	
 
Api CLI client for Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Jun 3, 2012
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
from __future__ import with_statement
 
import sys
 
import argparse
 

	
 
from kallithea.bin.base import json, api_call, RcConf, FORMAT_JSON, FORMAT_PRETTY
 

	
 

	
 
def argparser(argv):
 
    usage = (
 
      "kallithea-api [-h] [--format=FORMAT] [--apikey=APIKEY] [--apihost=APIHOST] "
 
      "[--config=CONFIG] [--save-config] "
 
      "METHOD <key:val> <key2:val> ...\n"
 
      "Create config file: kallithea-api --apikey=<key> --apihost=http://your.kallithea.server --save-config"
 
    )
 

	
 
    parser = argparse.ArgumentParser(description='Kallithea API cli',
 
                                     usage=usage)
 

	
 
    ## config
 
    group = parser.add_argument_group('config')
 
    group.add_argument('--apikey', help='api access key')
 
    group.add_argument('--apihost', help='api host')
 
    group.add_argument('--config', help='config file')
 
    group.add_argument('--save-config', action='store_true', help='save the given config into a file')
 

	
 
    group = parser.add_argument_group('API')
 
    group.add_argument('method', metavar='METHOD', nargs='?', type=str, default=None,
 
            help='API method name to call followed by key:value attributes',
 
    )
 
    group.add_argument('--format', dest='format', type=str,
 
            help='output format default: `%s` can '
 
                 'be also `%s`' % (FORMAT_PRETTY, FORMAT_JSON),
 
            default=FORMAT_PRETTY
 
    )
 
    args, other = parser.parse_known_args()
 
    return parser, args, other
 

	
 

	
 
def main(argv=None):
 
    """
 
    Main execution function for cli
 

	
 
    :param argv:
 
    """
 
    if argv is None:
 
        argv = sys.argv
 

	
 
    conf = None
 
    parser, args, other = argparser(argv)
 

	
 
    api_credentials_given = (args.apikey and args.apihost)
 
    if args.save_config:
 
        if not api_credentials_given:
 
            raise parser.error('--save-config requires --apikey and --apihost')
 
        conf = RcConf(config_location=args.config,
 
                      autocreate=True, config={'apikey': args.apikey,
 
                                               'apihost': args.apihost})
 
        sys.exit()
 

	
 
    if not conf:
 
        conf = RcConf(config_location=args.config, autoload=True)
 
        if not conf:
 
            if not api_credentials_given:
 
                parser.error('Could not find config file and missing '
 
                             '--apikey or --apihost in params')
 

	
 
    apikey = args.apikey or conf['apikey']
 
    apihost = args.apihost or conf['apihost']
 
    method = args.method
 

	
 
    # if we don't have method here it's an error
 
    if not method:
 
        parser.error('Please specify method name')
 

	
 
    try:
 
        margs = dict(map(lambda s: s.split(':', 1), other))
 
    except Exception:
 
    except ValueError:
 
        sys.stderr.write('Error parsing arguments \n')
 
        sys.exit()
 
    if args.format == FORMAT_PRETTY:
 
        print 'Calling method %s => %s' % (method, apihost)
 

	
 
    json_resp = api_call(apikey, apihost, method, **margs)
 
    error_prefix = ''
 
    if json_resp['error']:
 
        error_prefix = 'ERROR:'
 
        json_data = json_resp['error']
 
    else:
 
        json_data = json_resp['result']
 
    if args.format == FORMAT_JSON:
 
        print json.dumps(json_data)
 
    elif args.format == FORMAT_PRETTY:
 
        print 'Server response \n%s%s' % (
 
            error_prefix, json.dumps(json_data, indent=4, sort_keys=True)
 
        )
 
    return 0
 

	
 
if __name__ == '__main__':
 
    sys.exit(main(sys.argv))
kallithea/bin/kallithea_config.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.bin.kallithea_config
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
configuration generator for Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Jun 18, 2013
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 

	
 
from __future__ import with_statement
 
import os
 
import sys
 
import uuid
 
import argparse
 
from mako.template import Template
 
TMPL = 'template.ini.mako'
 
here = os.path.dirname(os.path.abspath(__file__))
 

	
 
def argparser(argv):
 
    usage = (
 
      "kallithea-config [-h] [--filename=FILENAME] [--template=TEMPLATE] \n"
 
      "VARS optional specify extra template variable that will be available in "
 
      "template. Use comma separated key=val format eg.\n"
 
      "key1=val1,port=5000,host=127.0.0.1,elements='a\,b\,c'\n"
 
    )
 

	
 
    parser = argparse.ArgumentParser(
 
        description='Kallithea CONFIG generator with variable replacement',
 
        usage=usage
 
    )
 

	
 
    ## config
 
    group = parser.add_argument_group('CONFIG')
 
    group.add_argument('--filename', help='Output ini filename.')
 
    group.add_argument('--template', help='Mako template file to use instead of '
 
                                          'the default builtin template')
 
    group.add_argument('--raw', help='Store given mako template as raw without '
 
                                     'parsing. Use this to create custom template '
 
                                     'initially', action='store_true')
 
    group.add_argument('--show-defaults', help='Show all default variables for '
 
                                               'builtin template', action='store_true')
 
    args, other = parser.parse_known_args()
 
    return parser, args, other
 

	
 

	
 
def _escape_split(text, sep):
 
    """
 
    Allows for escaping of the separator: e.g. arg='foo\, bar'
 

	
 
    It should be noted that the way bash et. al. do command line parsing, those
 
    single quotes are required. a shameless ripoff from fabric project.
 

	
 
    """
 
    escaped_sep = r'\%s' % sep
 

	
 
    if escaped_sep not in text:
 
        return text.split(sep)
 

	
 
    before, _, after = text.partition(escaped_sep)
 
    startlist = before.split(sep)  # a regular split is fine here
 
    unfinished = startlist[-1]
 
    startlist = startlist[:-1]
 

	
 
    # recurse because there may be more escaped separators
 
    endlist = _escape_split(after, sep)
 

	
 
    # finish building the escaped value. we use endlist[0] becaue the first
 
    # part of the string sent in recursion is the rest of the escaped value.
 
    unfinished += sep + endlist[0]
 

	
 
    return startlist + [unfinished] + endlist[1:]  # put together all the parts
 

	
 
def _run(argv):
 
    parser, args, other = argparser(argv)
 
    if not len(sys.argv) > 1:
 
        print parser.print_help()
 
        sys.exit(0)
 
    # defaults that can be overwritten by arguments
 
    tmpl_stored_args = {
 
        'http_server': 'waitress',
 
        'lang': 'en',
 
        'database_engine': 'sqlite',
 
        'host': '127.0.0.1',
 
        'port': 5000,
 
        'error_aggregation_service': None,
 
    }
 
    if other:
 
        # parse arguments, we assume only first is correct
 
        kwargs = {}
 
        for el in _escape_split(other[0], ','):
 
            kv = _escape_split(el, '=')
 
            if len(kv) == 2:
 
                k, v = kv
 
                kwargs[k] = v
 
        # update our template stored args
 
        tmpl_stored_args.update(kwargs)
 

	
 
    # use default that cannot be replaced
 
    tmpl_stored_args.update({
 
        'uuid': lambda: uuid.uuid4().hex,
 
        'here': os.path.abspath(os.curdir),
 
    })
 
    if args.show_defaults:
 
        for k,v in tmpl_stored_args.iteritems():
 
            print '%s=%s' % (k, v)
 
        sys.exit(0)
 
    try:
 
        # built in template
 
        tmpl_file = os.path.join(here, TMPL)
 
        if args.template:
 
            tmpl_file = args.template
 

	
 
        with open(tmpl_file, 'rb') as f:
 
            tmpl_data = f.read()
 
            if args.raw:
 
                tmpl = tmpl_data
 
            else:
 
                tmpl = Template(tmpl_data).render(**tmpl_stored_args)
 
        with open(args.filename, 'wb') as f:
 
            f.write(tmpl)
 
        print 'Wrote new config file in %s' % (os.path.abspath(args.filename))
 

	
 
    except Exception:
 
        from mako import exceptions
 
        print exceptions.text_error_template().render()
 

	
 
def main(argv=None):
 
    """
 
    Main execution function for cli
 

	
 
    :param argv:
 
    """
 
    if argv is None:
 
        argv = sys.argv
 

	
 
    try:
 
        return _run(argv)
 
    except Exception:
 
        raise
 
    return _run(argv)
 

	
 

	
 
if __name__ == '__main__':
 
    sys.exit(main(sys.argv))
kallithea/controllers/admin/admin.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.controllers.admin.admin
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Controller for Admin panel of Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 7, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 

	
 
import logging
 

	
 
from pylons import request, tmpl_context as c, url
 
from sqlalchemy.orm import joinedload
 
from whoosh.qparser.default import QueryParser
 
from whoosh.qparser.dateparse import DateParserPlugin
 
from whoosh import query
 
from sqlalchemy.sql.expression import or_, and_, func
 

	
 
from kallithea.model.db import UserLog
 
from kallithea.lib.auth import LoginRequired, HasPermissionAllDecorator
 
from kallithea.lib.base import BaseController, render
 
from kallithea.lib.utils2 import safe_int, remove_prefix, remove_suffix
 
from kallithea.lib.indexers import JOURNAL_SCHEMA
 
from kallithea.lib.helpers import Page
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
def _journal_filter(user_log, search_term):
 
    """
 
    Filters sqlalchemy user_log based on search_term with whoosh Query language
 
    http://packages.python.org/Whoosh/querylang.html
 

	
 
    :param user_log:
 
    :param search_term:
 
    """
 
    log.debug('Initial search term: %r' % search_term)
 
    qry = None
 
    if search_term:
 
        qp = QueryParser('repository', schema=JOURNAL_SCHEMA)
 
        qp.add_plugin(DateParserPlugin())
 
        qry = qp.parse(unicode(search_term))
 
        log.debug('Filtering using parsed query %r' % qry)
 

	
 
    def wildcard_handler(col, wc_term):
 
        if wc_term.startswith('*') and not wc_term.endswith('*'):
 
            #postfix == endswith
 
            wc_term = remove_prefix(wc_term, prefix='*')
 
            return func.lower(col).endswith(wc_term)
 
        elif wc_term.startswith('*') and wc_term.endswith('*'):
 
            #wildcard == ilike
 
            wc_term = remove_prefix(wc_term, prefix='*')
 
            wc_term = remove_suffix(wc_term, suffix='*')
 
            return func.lower(col).contains(wc_term)
 

	
 
    def get_filterion(field, val, term):
 

	
 
        if field == 'repository':
 
            field = getattr(UserLog, 'repository_name')
 
        elif field == 'ip':
 
            field = getattr(UserLog, 'user_ip')
 
        elif field == 'date':
 
            field = getattr(UserLog, 'action_date')
 
        elif field == 'username':
 
            field = getattr(UserLog, 'username')
 
        else:
 
            field = getattr(UserLog, field)
 
        log.debug('filter field: %s val=>%s' % (field, val))
 

	
 
        #sql filtering
 
        if isinstance(term, query.Wildcard):
 
            return wildcard_handler(field, val)
 
        elif isinstance(term, query.Prefix):
 
            return func.lower(field).startswith(func.lower(val))
 
        elif isinstance(term, query.DateRange):
 
            return and_(field >= val[0], field <= val[1])
 
        return func.lower(field) == func.lower(val)
 

	
 
    if isinstance(qry, (query.And, query.Term, query.Prefix, query.Wildcard,
 
                        query.DateRange)):
 
        if not isinstance(qry, query.And):
 
            qry = [qry]
 
        for term in qry:
 
            field = term.fieldname
 
            val = (term.text if not isinstance(term, query.DateRange)
 
                   else [term.startdate, term.enddate])
 
            user_log = user_log.filter(get_filterion(field, val, term))
 
    elif isinstance(qry, query.Or):
 
        filters = []
 
        for term in qry:
 
            field = term.fieldname
 
            val = (term.text if not isinstance(term, query.DateRange)
 
                   else [term.startdate, term.enddate])
 
            filters.append(get_filterion(field, val, term))
 
        user_log = user_log.filter(or_(*filters))
 

	
 
    return user_log
 

	
 

	
 
class AdminController(BaseController):
 

	
 
    @LoginRequired()
 
    def __before__(self):
 
        super(AdminController, self).__before__()
 

	
 
    @HasPermissionAllDecorator('hg.admin')
 
    def index(self):
 
        users_log = UserLog.query()\
 
                .options(joinedload(UserLog.user))\
 
                .options(joinedload(UserLog.repository))
 

	
 
        #FILTERING
 
        c.search_term = request.GET.get('filter')
 
        try:
 
            users_log = _journal_filter(users_log, c.search_term)
 
        except Exception:
 
            # we want this to crash for now
 
            raise
 
        users_log = _journal_filter(users_log, c.search_term)
 

	
 
        users_log = users_log.order_by(UserLog.action_date.desc())
 

	
 
        p = safe_int(request.GET.get('page', 1), 1)
 

	
 
        def url_generator(**kw):
 
            return url.current(filter=c.search_term, **kw)
 

	
 
        c.users_log = Page(users_log, page=p, items_per_page=10, url=url_generator)
 
        c.log_data = render('admin/admin_log.html')
 

	
 
        if request.environ.get('HTTP_X_PARTIAL_XHR'):
 
            return c.log_data
 
        return render('admin/admin.html')
kallithea/controllers/changeset.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.controllers.changeset
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
changeset controller for pylons showoing changes beetween
 
revisions
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 25, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
import logging
 
import traceback
 
from collections import defaultdict
 
from webob.exc import HTTPForbidden, HTTPBadRequest, HTTPNotFound
 

	
 
from pylons import tmpl_context as c, request, response
 
from pylons.i18n.translation import _
 
from pylons.controllers.util import redirect
 
from kallithea.lib.utils import jsonify
 

	
 
from kallithea.lib.vcs.exceptions import RepositoryError, \
 
    ChangesetDoesNotExistError
 

	
 
from kallithea.lib.compat import json
 
import kallithea.lib.helpers as h
 
from kallithea.lib.auth import LoginRequired, HasRepoPermissionAnyDecorator,\
 
    NotAnonymous
 
from kallithea.lib.base import BaseRepoController, render
 
from kallithea.lib.utils import action_logger
 
from kallithea.lib.compat import OrderedDict
 
from kallithea.lib import diffs
 
from kallithea.model.db import ChangesetComment, ChangesetStatus
 
from kallithea.model.comment import ChangesetCommentsModel
 
from kallithea.model.changeset_status import ChangesetStatusModel
 
from kallithea.model.meta import Session
 
from kallithea.model.repo import RepoModel
 
from kallithea.lib.diffs import LimitedDiffContainer
 
from kallithea.lib.exceptions import StatusChangeOnClosedPullRequestError
 
from kallithea.lib.vcs.backends.base import EmptyChangeset
 
from kallithea.lib.utils2 import safe_unicode
 
from kallithea.lib.graphmod import graph_data
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
def _update_with_GET(params, GET):
 
    for k in ['diff1', 'diff2', 'diff']:
 
        params[k] += GET.getall(k)
 

	
 

	
 
def anchor_url(revision, path, GET):
 
    fid = h.FID(revision, path)
 
    return h.url.current(anchor=fid, **dict(GET))
 

	
 

	
 
def get_ignore_ws(fid, GET):
 
    ig_ws_global = GET.get('ignorews')
 
    ig_ws = filter(lambda k: k.startswith('WS'), GET.getall(fid))
 
    if ig_ws:
 
        try:
 
            return int(ig_ws[0].split(':')[-1])
 
        except Exception:
 
            pass
 
        return int(ig_ws[0].split(':')[-1])
 
    return ig_ws_global
 

	
 

	
 
def _ignorews_url(GET, fileid=None):
 
    fileid = str(fileid) if fileid else None
 
    params = defaultdict(list)
 
    _update_with_GET(params, GET)
 
    lbl = _('Show whitespace')
 
    ig_ws = get_ignore_ws(fileid, GET)
 
    ln_ctx = get_line_ctx(fileid, GET)
 
    # global option
 
    if fileid is None:
 
        if ig_ws is None:
 
            params['ignorews'] += [1]
 
            lbl = _('Ignore whitespace')
 
        ctx_key = 'context'
 
        ctx_val = ln_ctx
 
    # per file options
 
    else:
 
        if ig_ws is None:
 
            params[fileid] += ['WS:1']
 
            lbl = _('Ignore whitespace')
 

	
 
        ctx_key = fileid
 
        ctx_val = 'C:%s' % ln_ctx
 
    # if we have passed in ln_ctx pass it along to our params
 
    if ln_ctx:
 
        params[ctx_key] += [ctx_val]
 

	
 
    params['anchor'] = fileid
 
    icon = h.literal('<i class="icon-strike"></i>')
 
    return h.link_to(icon, h.url.current(**params), title=lbl, class_='tooltip')
 

	
 

	
 
def get_line_ctx(fid, GET):
 
    ln_ctx_global = GET.get('context')
 
    if fid:
 
        ln_ctx = filter(lambda k: k.startswith('C'), GET.getall(fid))
 
    else:
 
        _ln_ctx = filter(lambda k: k.startswith('C'), GET)
 
        ln_ctx = GET.get(_ln_ctx[0]) if _ln_ctx  else ln_ctx_global
 
        if ln_ctx:
 
            ln_ctx = [ln_ctx]
 

	
 
    if ln_ctx:
 
        retval = ln_ctx[0].split(':')[-1]
 
    else:
 
        retval = ln_ctx_global
 

	
 
    try:
 
        return int(retval)
 
    except Exception:
 
        return 3
 

	
 

	
 
def _context_url(GET, fileid=None):
 
    """
 
    Generates url for context lines
 

	
 
    :param fileid:
 
    """
 

	
 
    fileid = str(fileid) if fileid else None
 
    ig_ws = get_ignore_ws(fileid, GET)
 
    ln_ctx = (get_line_ctx(fileid, GET) or 3) * 2
 

	
 
    params = defaultdict(list)
 
    _update_with_GET(params, GET)
 

	
 
    # global option
 
    if fileid is None:
 
        if ln_ctx > 0:
 
            params['context'] += [ln_ctx]
 

	
 
        if ig_ws:
 
            ig_ws_key = 'ignorews'
 
            ig_ws_val = 1
 

	
 
    # per file option
 
    else:
 
        params[fileid] += ['C:%s' % ln_ctx]
 
        ig_ws_key = fileid
 
        ig_ws_val = 'WS:%s' % 1
 

	
 
    if ig_ws:
 
        params[ig_ws_key] += [ig_ws_val]
 

	
 
    lbl = _('increase diff context to %(num)s lines') % {'num': ln_ctx}
 

	
 
    params['anchor'] = fileid
 
    icon = h.literal('<i class="icon-sort"></i>')
 
    return h.link_to(icon, h.url.current(**params), title=lbl, class_='tooltip')
 

	
 

	
 
class ChangesetController(BaseRepoController):
 

	
 
    def __before__(self):
 
        super(ChangesetController, self).__before__()
 
        c.affected_files_cut_off = 60
 

	
 
    def __load_data(self):
 
        repo_model = RepoModel()
 
        c.users_array = repo_model.get_users_js()
 
        c.user_groups_array = repo_model.get_user_groups_js()
 

	
 
    def _index(self, revision, method):
 
        c.anchor_url = anchor_url
 
        c.ignorews_url = _ignorews_url
 
        c.context_url = _context_url
 
        c.fulldiff = fulldiff = request.GET.get('fulldiff')
 
        #get ranges of revisions if preset
 
        rev_range = revision.split('...')[:2]
 
        enable_comments = True
 
        c.cs_repo = c.db_repo
 
        try:
 
            if len(rev_range) == 2:
 
                enable_comments = False
 
                rev_start = rev_range[0]
 
                rev_end = rev_range[1]
 
                rev_ranges = c.db_repo_scm_instance.get_changesets(start=rev_start,
 
                                                             end=rev_end)
 
            else:
 
                rev_ranges = [c.db_repo_scm_instance.get_changeset(revision)]
 

	
 
            c.cs_ranges = list(rev_ranges)
 
            if not c.cs_ranges:
 
                raise RepositoryError('Changeset range returned empty result')
 

	
 
        except(ChangesetDoesNotExistError,), e:
 
            log.error(traceback.format_exc())
 
            msg = _('Such revision does not exist for this repository')
 
            h.flash(msg, category='error')
 
            raise HTTPNotFound()
 

	
 
        c.changes = OrderedDict()
 

	
 
        c.lines_added = 0  # count of lines added
 
        c.lines_deleted = 0  # count of lines removes
 

	
 
        c.changeset_statuses = ChangesetStatus.STATUSES
 
        comments = dict()
 
        c.statuses = []
 
        c.inline_comments = []
 
        c.inline_cnt = 0
 

	
 
        # Iterate over ranges (default changeset view is always one changeset)
 
        for changeset in c.cs_ranges:
 
            inlines = []
 
            if method == 'show':
 
                c.statuses.extend([ChangesetStatusModel().get_status(
 
                            c.db_repo.repo_id, changeset.raw_id)])
 

	
 
                # Changeset comments
 
                comments.update((com.comment_id, com)
 
                                for com in ChangesetCommentsModel()
 
                                .get_comments(c.db_repo.repo_id,
 
                                              revision=changeset.raw_id))
 

	
 
                # Status change comments - mostly from pull requests
 
                comments.update((st.changeset_comment_id, st.comment)
 
                                for st in ChangesetStatusModel()
 
                                .get_statuses(c.db_repo.repo_id,
 
                                              changeset.raw_id, with_revisions=True))
 

	
 
                inlines = ChangesetCommentsModel()\
 
                            .get_inline_comments(c.db_repo.repo_id,
 
                                                 revision=changeset.raw_id)
 
                c.inline_comments.extend(inlines)
 

	
 
            c.changes[changeset.raw_id] = []
 

	
 
            cs2 = changeset.raw_id
 
            cs1 = changeset.parents[0].raw_id if changeset.parents else EmptyChangeset().raw_id
 
            context_lcl = get_line_ctx('', request.GET)
 
            ign_whitespace_lcl = ign_whitespace_lcl = get_ignore_ws('', request.GET)
 

	
 
            _diff = c.db_repo_scm_instance.get_diff(cs1, cs2,
 
                ignore_whitespace=ign_whitespace_lcl, context=context_lcl)
 
            diff_limit = self.cut_off_limit if not fulldiff else None
 
            diff_processor = diffs.DiffProcessor(_diff,
 
                                                 vcs=c.db_repo_scm_instance.alias,
 
                                                 format='gitdiff',
 
                                                 diff_limit=diff_limit)
 
            cs_changes = OrderedDict()
 
            if method == 'show':
 
                _parsed = diff_processor.prepare()
 
                c.limited_diff = False
 
                if isinstance(_parsed, LimitedDiffContainer):
 
                    c.limited_diff = True
 
                for f in _parsed:
 
                    st = f['stats']
 
                    c.lines_added += st['added']
 
                    c.lines_deleted += st['deleted']
 
                    fid = h.FID(changeset.raw_id, f['filename'])
 
                    diff = diff_processor.as_html(enable_comments=enable_comments,
 
                                                  parsed_lines=[f])
 
                    cs_changes[fid] = [cs1, cs2, f['operation'], f['filename'],
 
                                       diff, st]
 
            else:
 
                # downloads/raw we only need RAW diff nothing else
 
                diff = diff_processor.as_raw()
 
                cs_changes[''] = [None, None, None, None, diff, None]
 
            c.changes[changeset.raw_id] = cs_changes
 

	
 
        #sort comments in creation order
 
        c.comments = [com for com_id, com in sorted(comments.items())]
 

	
 
        # count inline comments
 
        for __, lines in c.inline_comments:
 
            for comments in lines.values():
 
                c.inline_cnt += len(comments)
 

	
 
        if len(c.cs_ranges) == 1:
 
            c.changeset = c.cs_ranges[0]
 
            c.parent_tmpl = ''.join(['# Parent  %s\n' % x.raw_id
 
                                     for x in c.changeset.parents])
 
        if method == 'download':
 
            response.content_type = 'text/plain'
 
            response.content_disposition = 'attachment; filename=%s.diff' \
 
                                            % revision[:12]
 
            return diff
 
        elif method == 'patch':
 
            response.content_type = 'text/plain'
 
            c.diff = safe_unicode(diff)
 
            return render('changeset/patch_changeset.html')
 
        elif method == 'raw':
 
            response.content_type = 'text/plain'
 
            return diff
 
        elif method == 'show':
 
            self.__load_data()
 
            if len(c.cs_ranges) == 1:
 
                return render('changeset/changeset.html')
 
            else:
 
                c.cs_ranges_org = None
 
                c.cs_comments = {}
 
                revs = [ctx.revision for ctx in reversed(c.cs_ranges)]
 
                c.jsdata = json.dumps(graph_data(c.db_repo_scm_instance, revs))
 
                return render('changeset/changeset_range.html')
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    def index(self, revision, method='show'):
 
        return self._index(revision, method=method)
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    def changeset_raw(self, revision):
 
        return self._index(revision, method='raw')
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    def changeset_patch(self, revision):
 
        return self._index(revision, method='patch')
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    def changeset_download(self, revision):
 
        return self._index(revision, method='download')
 

	
 
    @LoginRequired()
 
    @NotAnonymous()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    @jsonify
 
    def comment(self, repo_name, revision):
 
        status = request.POST.get('changeset_status')
 
        text = request.POST.get('text', '').strip() or _('No comments.')
 

	
 
        c.co = comm = ChangesetCommentsModel().create(
 
            text=text,
 
            repo=c.db_repo.repo_id,
 
            user=c.authuser.user_id,
 
            revision=revision,
 
            f_path=request.POST.get('f_path'),
 
            line_no=request.POST.get('line'),
 
            status_change=(ChangesetStatus.get_status_lbl(status)
 
                           if status else None)
 
        )
 

	
 
        # get status if set !
 
        if status:
 
            # if latest status was from pull request and it's closed
 
            # disallow changing status !
 
            # dont_allow_on_closed_pull_request = True !
 

	
 
            try:
 
                ChangesetStatusModel().set_status(
 
                    c.db_repo.repo_id,
 
                    status,
 
                    c.authuser.user_id,
 
                    comm,
 
                    revision=revision,
 
                    dont_allow_on_closed_pull_request=True
 
                )
 
            except StatusChangeOnClosedPullRequestError:
 
                log.error(traceback.format_exc())
 
                msg = _('Changing status on a changeset associated with '
 
                        'a closed pull request is not allowed')
 
                h.flash(msg, category='warning')
 
                return redirect(h.url('changeset_home', repo_name=repo_name,
 
                                      revision=revision))
 
        action_logger(self.authuser,
 
                      'user_commented_revision:%s' % revision,
 
                      c.db_repo, self.ip_addr, self.sa)
 

	
 
        Session().commit()
 

	
 
        if not request.environ.get('HTTP_X_PARTIAL_XHR'):
 
            return redirect(h.url('changeset_home', repo_name=repo_name,
 
                                  revision=revision))
 
        #only ajax below
 
        data = {
 
           'target_id': h.safeid(h.safe_unicode(request.POST.get('f_path'))),
 
        }
 
        if comm:
 
            data.update(comm.get_dict())
 
            data.update({'rendered_text':
 
                         render('changeset/changeset_comment_block.html')})
 

	
 
        return data
 

	
 
    @LoginRequired()
 
    @NotAnonymous()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    def preview_comment(self):
 
        if not request.environ.get('HTTP_X_PARTIAL_XHR'):
 
            raise HTTPBadRequest()
 
        text = request.POST.get('text')
 
        if text:
 
            return h.rst_w_mentions(text)
 
        return ''
 

	
 
    @LoginRequired()
 
    @NotAnonymous()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    @jsonify
 
    def delete_comment(self, repo_name, comment_id):
 
        co = ChangesetComment.get(comment_id)
 
        if not co:
 
            raise HTTPBadRequest()
 
        owner = co.author.user_id == c.authuser.user_id
 
        repo_admin = h.HasRepoPermissionAny('repository.admin')
 
        if h.HasPermissionAny('hg.admin')() or repo_admin or owner:
 
            ChangesetCommentsModel().delete(comment=co)
 
            Session().commit()
 
            return True
 
        else:
 
            raise HTTPForbidden()
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    @jsonify
 
    def changeset_info(self, repo_name, revision):
 
        if request.is_xhr:
 
            try:
 
                return c.db_repo_scm_instance.get_changeset(revision)
 
            except ChangesetDoesNotExistError, e:
 
                return EmptyChangeset(message=str(e))
 
        else:
 
            raise HTTPBadRequest()
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    @jsonify
 
    def changeset_children(self, repo_name, revision):
 
        if request.is_xhr:
 
            changeset = c.db_repo_scm_instance.get_changeset(revision)
 
            result = {"results": []}
 
            if changeset.children:
 
                result = {"results": changeset.children}
 
            return result
 
        else:
 
            raise HTTPBadRequest()
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    @jsonify
 
    def changeset_parents(self, repo_name, revision):
 
        if request.is_xhr:
 
            changeset = c.db_repo_scm_instance.get_changeset(revision)
 
            result = {"results": []}
 
            if changeset.parents:
 
                result = {"results": changeset.parents}
 
            return result
 
        else:
 
            raise HTTPBadRequest()
kallithea/controllers/error.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.controllers.error
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Kallithea error controller
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Dec 8, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
import os
 
import cgi
 
import logging
 
import paste.fileapp
 

	
 
from pylons import tmpl_context as c, request, config
 
from pylons.i18n.translation import _
 
from pylons.middleware import  media_path
 

	
 
from kallithea.lib.base import BaseController, render
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class ErrorController(BaseController):
 
    """Generates error documents as and when they are required.
 

	
 
    The ErrorDocuments middleware forwards to ErrorController when error
 
    related status codes are returned from the application.
 

	
 
    This behavior can be altered by changing the parameters to the
 
    ErrorDocuments middleware in your config/middleware.py file.
 
    """
 

	
 
    def __before__(self):
 
        #disable all base actions since we don't need them here
 
        pass
 

	
 
    def document(self):
 
        resp = request.environ.get('pylons.original_response')
 
        c.site_name = config.get('title')
 

	
 
        log.debug('### %s ###' % resp.status)
 

	
 
        e = request.environ
 
        c.serv_p = r'%(protocol)s://%(host)s/' \
 
                                    % {'protocol': e.get('wsgi.url_scheme'),
 
                                       'host': e.get('HTTP_HOST'), }
 

	
 
        c.error_message = cgi.escape(request.GET.get('code', str(resp.status)))
 
        c.error_explanation = self.get_error_explanation(resp.status_int)
 

	
 
        return render('/errors/error_document.html')
 

	
 
    def img(self, id):
 
        """Serve Pylons' stock images"""
 
        return self._serve_file(os.path.join(media_path, 'img', id))
 

	
 
    def style(self, id):
 
        """Serve Pylons' stock stylesheets"""
 
        return self._serve_file(os.path.join(media_path, 'style', id))
 

	
 
    def _serve_file(self, path):
 
        """Call Paste's FileApp (a WSGI application) to serve the file
 
        at the specified path
 
        """
 
        fapp = paste.fileapp.FileApp(path)
 
        return fapp(request.environ, self.start_response)
 

	
 
    def get_error_explanation(self, code):
 
        """ get the error explanations of int codes
 
            [400, 401, 403, 404, 500]"""
 
        try:
 
            code = int(code)
 
        except Exception:
 
        except ValueError:
 
            code = 500
 

	
 
        if code == 400:
 
            return _('The request could not be understood by the server'
 
                     ' due to malformed syntax.')
 
        if code == 401:
 
            return _('Unauthorized access to resource')
 
        if code == 403:
 
            return _("You don't have permission to view this page")
 
        if code == 404:
 
            return _('The resource could not be found')
 
        if code == 500:
 
            return _('The server encountered an unexpected condition'
 
                     ' which prevented it from fulfilling the request.')
kallithea/controllers/journal.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.controllers.journal
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Journal controller for pylons
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Nov 21, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 

	
 
"""
 

	
 
import logging
 
import traceback
 
from itertools import groupby
 

	
 
from sqlalchemy import or_
 
from sqlalchemy.orm import joinedload
 
from sqlalchemy.sql.expression import func
 

	
 
from webhelpers.feedgenerator import Atom1Feed, Rss201rev2Feed
 

	
 
from webob.exc import HTTPBadRequest
 
from pylons import request, tmpl_context as c, response, url
 
from pylons.i18n.translation import _
 

	
 
from kallithea.controllers.admin.admin import _journal_filter
 
from kallithea.model.db import UserLog, UserFollowing, Repository, User
 
from kallithea.model.meta import Session
 
from kallithea.model.repo import RepoModel
 
import kallithea.lib.helpers as h
 
from kallithea.lib.helpers import Page
 
from kallithea.lib.auth import LoginRequired, NotAnonymous
 
from kallithea.lib.base import BaseController, render
 
from kallithea.lib.utils2 import safe_int, AttributeDict
 
from kallithea.lib.compat import json
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class JournalController(BaseController):
 

	
 
    def __before__(self):
 
        super(JournalController, self).__before__()
 
        self.language = 'en-us'
 
        self.ttl = "5"
 
        self.feed_nr = 20
 
        c.search_term = request.GET.get('filter')
 

	
 
    def _get_daily_aggregate(self, journal):
 
        groups = []
 
        for k, g in groupby(journal, lambda x: x.action_as_day):
 
            user_group = []
 
            #groupby username if it's a present value, else fallback to journal username
 
            for _unused, g2 in groupby(list(g), lambda x: x.user.username if x.user else x.username):
 
                l = list(g2)
 
                user_group.append((l[0].user, l))
 

	
 
            groups.append((k, user_group,))
 

	
 
        return groups
 

	
 
    def _get_journal_data(self, following_repos):
 
        repo_ids = [x.follows_repository.repo_id for x in following_repos
 
                    if x.follows_repository is not None]
 
        user_ids = [x.follows_user.user_id for x in following_repos
 
                    if x.follows_user is not None]
 

	
 
        filtering_criterion = None
 

	
 
        if repo_ids and user_ids:
 
            filtering_criterion = or_(UserLog.repository_id.in_(repo_ids),
 
                        UserLog.user_id.in_(user_ids))
 
        if repo_ids and not user_ids:
 
            filtering_criterion = UserLog.repository_id.in_(repo_ids)
 
        if not repo_ids and user_ids:
 
            filtering_criterion = UserLog.user_id.in_(user_ids)
 
        if filtering_criterion is not None:
 
            journal = self.sa.query(UserLog)\
 
                .options(joinedload(UserLog.user))\
 
                .options(joinedload(UserLog.repository))
 
            #filter
 
            try:
 
                journal = _journal_filter(journal, c.search_term)
 
            except Exception:
 
                # we want this to crash for now
 
                raise
 
            journal = _journal_filter(journal, c.search_term)
 
            journal = journal.filter(filtering_criterion)\
 
                        .order_by(UserLog.action_date.desc())
 
        else:
 
            journal = []
 

	
 
        return journal
 

	
 
    def _atom_feed(self, repos, public=True):
 
        journal = self._get_journal_data(repos)
 
        if public:
 
            _link = h.canonical_url('public_journal_atom')
 
            _desc = '%s %s %s' % (c.site_name, _('public journal'),
 
                                  'atom feed')
 
        else:
 
            _link = h.canonical_url('journal_atom')
 
            _desc = '%s %s %s' % (c.site_name, _('journal'), 'atom feed')
 

	
 
        feed = Atom1Feed(title=_desc,
 
                         link=_link,
 
                         description=_desc,
 
                         language=self.language,
 
                         ttl=self.ttl)
 

	
 
        for entry in journal[:self.feed_nr]:
 
            user = entry.user
 
            if user is None:
 
                #fix deleted users
 
                user = AttributeDict({'short_contact': entry.username,
 
                                      'email': '',
 
                                      'full_contact': ''})
 
            action, action_extra, ico = h.action_parser(entry, feed=True)
 
            title = "%s - %s %s" % (user.short_contact, action(),
 
                                    entry.repository.repo_name)
 
            desc = action_extra()
 
            _url = None
 
            if entry.repository is not None:
 
                _url = h.canonical_url('changelog_home',
 
                           repo_name=entry.repository.repo_name)
 

	
 
            feed.add_item(title=title,
 
                          pubdate=entry.action_date,
 
                          link=_url or h.canonical_url(''),
 
                          author_email=user.email,
 
                          author_name=user.full_contact,
 
                          description=desc)
 

	
 
        response.content_type = feed.mime_type
 
        return feed.writeString('utf-8')
 

	
 
    def _rss_feed(self, repos, public=True):
 
        journal = self._get_journal_data(repos)
 
        if public:
 
            _link = h.canonical_url('public_journal_atom')
 
            _desc = '%s %s %s' % (c.site_name, _('public journal'),
 
                                  'rss feed')
 
        else:
 
            _link = h.canonical_url('journal_atom')
 
            _desc = '%s %s %s' % (c.site_name, _('journal'), 'rss feed')
 

	
 
        feed = Rss201rev2Feed(title=_desc,
 
                         link=_link,
 
                         description=_desc,
 
                         language=self.language,
 
                         ttl=self.ttl)
 

	
 
        for entry in journal[:self.feed_nr]:
 
            user = entry.user
 
            if user is None:
 
                #fix deleted users
 
                user = AttributeDict({'short_contact': entry.username,
 
                                      'email': '',
 
                                      'full_contact': ''})
 
            action, action_extra, ico = h.action_parser(entry, feed=True)
 
            title = "%s - %s %s" % (user.short_contact, action(),
 
                                    entry.repository.repo_name)
 
            desc = action_extra()
 
            _url = None
 
            if entry.repository is not None:
 
                _url = h.canonical_url('changelog_home',
 
                           repo_name=entry.repository.repo_name)
 

	
 
            feed.add_item(title=title,
 
                          pubdate=entry.action_date,
 
                          link=_url or h.canonical_url(''),
 
                          author_email=user.email,
 
                          author_name=user.full_contact,
 
                          description=desc)
 

	
 
        response.content_type = feed.mime_type
 
        return feed.writeString('utf-8')
 

	
 
    @LoginRequired()
 
    @NotAnonymous()
 
    def index(self):
 
        # Return a rendered template
 
        p = safe_int(request.GET.get('page', 1), 1)
 
        c.user = User.get(self.authuser.user_id)
 
        c.following = self.sa.query(UserFollowing)\
 
            .filter(UserFollowing.user_id == self.authuser.user_id)\
 
            .options(joinedload(UserFollowing.follows_repository))\
 
            .all()
 

	
 
        journal = self._get_journal_data(c.following)
 

	
 
        def url_generator(**kw):
 
            return url.current(filter=c.search_term, **kw)
 

	
 
        c.journal_pager = Page(journal, page=p, items_per_page=20, url=url_generator)
 
        c.journal_day_aggreagate = self._get_daily_aggregate(c.journal_pager)
 

	
 
        c.journal_data = render('journal/journal_data.html')
 
        if request.environ.get('HTTP_X_PARTIAL_XHR'):
 
            return c.journal_data
 

	
 
        repos_list = Session().query(Repository)\
 
                     .filter(Repository.user_id ==
 
                             self.authuser.user_id)\
 
                     .order_by(func.lower(Repository.repo_name)).all()
 

	
 
        repos_data = RepoModel().get_repos_as_dict(repos_list=repos_list,
 
                                                   admin=True)
 
        #json used to render the grid
 
        c.data = json.dumps(repos_data)
 

	
 
        watched_repos_data = []
 

	
 
        ## watched repos
 
        _render = RepoModel._render_datatable
 

	
 
        def quick_menu(repo_name):
 
            return _render('quick_menu', repo_name)
 

	
 
        def repo_lnk(name, rtype, rstate, private, fork_of):
 
            return _render('repo_name', name, rtype, rstate, private, fork_of,
 
                           short_name=False, admin=False)
 

	
 
        def last_rev(repo_name, cs_cache):
 
            return _render('revision', repo_name, cs_cache.get('revision'),
 
                           cs_cache.get('raw_id'), cs_cache.get('author'),
 
                           cs_cache.get('message'))
 

	
 
        def desc(desc):
 
            from pylons import tmpl_context as c
 
            if c.visual.stylify_metatags:
 
                return h.urlify_text(h.desc_stylize(h.truncate(desc, 60)))
 
            else:
 
                return h.urlify_text(h.truncate(desc, 60))
 

	
 
        def repo_actions(repo_name):
 
            return _render('repo_actions', repo_name)
 

	
 
        def owner_actions(user_id, username):
 
            return _render('user_name', user_id, username)
 

	
 
        def toogle_follow(repo_id):
 
            return  _render('toggle_follow', repo_id)
 

	
 
        for entry in c.following:
 
            repo = entry.follows_repository
 
            cs_cache = repo.changeset_cache
 
            row = {
 
                "menu": quick_menu(repo.repo_name),
 
                "raw_name": repo.repo_name.lower(),
 
                "name": repo_lnk(repo.repo_name, repo.repo_type,
 
                                 repo.repo_state, repo.private, repo.fork),
 
                "last_changeset": last_rev(repo.repo_name, cs_cache),
 
                "last_rev_raw": cs_cache.get('revision'),
 
                "action": toogle_follow(repo.repo_id)
 
            }
 

	
 
            watched_repos_data.append(row)
 

	
 
        c.watched_data = json.dumps({
 
            "totalRecords": len(c.following),
 
            "startIndex": 0,
 
            "sort": "name",
 
            "dir": "asc",
 
            "records": watched_repos_data
 
        })
 
        return render('journal/journal.html')
 

	
 
    @LoginRequired(api_access=True)
 
    @NotAnonymous()
 
    def journal_atom(self):
 
        """
 
        Produce an atom-1.0 feed via feedgenerator module
 
        """
 
        following = self.sa.query(UserFollowing)\
 
            .filter(UserFollowing.user_id == self.authuser.user_id)\
 
            .options(joinedload(UserFollowing.follows_repository))\
 
            .all()
 
        return self._atom_feed(following, public=False)
 

	
 
    @LoginRequired(api_access=True)
 
    @NotAnonymous()
 
    def journal_rss(self):
 
        """
 
        Produce an rss feed via feedgenerator module
 
        """
 
        following = self.sa.query(UserFollowing)\
 
            .filter(UserFollowing.user_id == self.authuser.user_id)\
 
            .options(joinedload(UserFollowing.follows_repository))\
 
            .all()
 
        return self._rss_feed(following, public=False)
 

	
 
    @LoginRequired()
 
    @NotAnonymous()
 
    def toggle_following(self):
 
        cur_token = request.POST.get('auth_token')
 
        token = h.get_token()
 
        if cur_token == token:
 

	
 
            user_id = request.POST.get('follows_user_id')
 
            if user_id:
 
                try:
 
                    self.scm_model.toggle_following_user(user_id,
 
                                                self.authuser.user_id)
 
                    Session.commit()
 
                    return 'ok'
 
                except Exception:
 
                    log.error(traceback.format_exc())
 
                    raise HTTPBadRequest()
 

	
 
            repo_id = request.POST.get('follows_repo_id')
 
            if repo_id:
 
                try:
 
                    self.scm_model.toggle_following_repo(repo_id,
 
                                                self.authuser.user_id)
 
                    Session.commit()
 
                    return 'ok'
 
                except Exception:
 
                    log.error(traceback.format_exc())
 
                    raise HTTPBadRequest()
 

	
 
        log.debug('token mismatch %s vs %s' % (cur_token, token))
 
        raise HTTPBadRequest()
 

	
 
    @LoginRequired()
 
    def public_journal(self):
 
        # Return a rendered template
 
        p = safe_int(request.GET.get('page', 1), 1)
 

	
 
        c.following = self.sa.query(UserFollowing)\
 
            .filter(UserFollowing.user_id == self.authuser.user_id)\
 
            .options(joinedload(UserFollowing.follows_repository))\
 
            .all()
 

	
 
        journal = self._get_journal_data(c.following)
 

	
 
        c.journal_pager = Page(journal, page=p, items_per_page=20)
 

	
 
        c.journal_day_aggreagate = self._get_daily_aggregate(c.journal_pager)
 

	
 
        c.journal_data = render('journal/journal_data.html')
 
        if request.environ.get('HTTP_X_PARTIAL_XHR'):
 
            return c.journal_data
 
        return render('journal/public_journal.html')
 

	
 
    @LoginRequired(api_access=True)
 
    def public_journal_atom(self):
 
        """
 
        Produce an atom-1.0 feed via feedgenerator module
 
        """
 
        c.following = self.sa.query(UserFollowing)\
 
            .filter(UserFollowing.user_id == self.authuser.user_id)\
 
            .options(joinedload(UserFollowing.follows_repository))\
 
            .all()
 

	
 
        return self._atom_feed(c.following)
 

	
 
    @LoginRequired(api_access=True)
 
    def public_journal_rss(self):
 
        """
 
        Produce an rss2 feed via feedgenerator module
 
        """
 
        c.following = self.sa.query(UserFollowing)\
 
            .filter(UserFollowing.user_id == self.authuser.user_id)\
 
            .options(joinedload(UserFollowing.follows_repository))\
 
            .all()
 

	
 
        return self._rss_feed(c.following)
kallithea/controllers/summary.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.controllers.summary
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Summary controller for Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 18, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
import traceback
 
import calendar
 
import logging
 
from time import mktime
 
from datetime import timedelta, date
 

	
 
from pylons import tmpl_context as c, request
 
from pylons.i18n.translation import _
 
from webob.exc import HTTPBadRequest
 

	
 
from beaker.cache import cache_region, region_invalidate
 

	
 
from kallithea.lib.compat import product
 
from kallithea.lib.vcs.exceptions import ChangesetError, EmptyRepositoryError, \
 
    NodeDoesNotExistError
 
from kallithea.config.conf import ALL_READMES, ALL_EXTS, LANGUAGES_EXTENSIONS_MAP
 
from kallithea.model.db import Statistics, CacheInvalidation, User
 
from kallithea.lib.utils import jsonify
 
from kallithea.lib.utils2 import safe_str
 
from kallithea.lib.auth import LoginRequired, HasRepoPermissionAnyDecorator,\
 
    NotAnonymous
 
from kallithea.lib.base import BaseRepoController, render
 
from kallithea.lib.vcs.backends.base import EmptyChangeset
 
from kallithea.lib.markup_renderer import MarkupRenderer
 
from kallithea.lib.celerylib import run_task
 
from kallithea.lib.celerylib.tasks import get_commits_stats
 
from kallithea.lib.compat import json
 
from kallithea.lib.vcs.nodes import FileNode
 
from kallithea.controllers.changelog import _load_changelog_summary
 

	
 
log = logging.getLogger(__name__)
 

	
 
README_FILES = [''.join([x[0][0], x[1][0]]) for x in
 
                    sorted(list(product(ALL_READMES, ALL_EXTS)),
 
                           key=lambda y:y[0][1] + y[1][1])]
 

	
 

	
 
class SummaryController(BaseRepoController):
 

	
 
    def __before__(self):
 
        super(SummaryController, self).__before__()
 

	
 
    def _get_download_links(self, repo):
 

	
 
        download_l = []
 

	
 
        branches_group = ([], _("Branches"))
 
        tags_group = ([], _("Tags"))
 

	
 
        for name, chs in c.db_repo_scm_instance.branches.items():
 
            #chs = chs.split(':')[-1]
 
            branches_group[0].append((chs, name),)
 
        download_l.append(branches_group)
 

	
 
        for name, chs in c.db_repo_scm_instance.tags.items():
 
            #chs = chs.split(':')[-1]
 
            tags_group[0].append((chs, name),)
 
        download_l.append(tags_group)
 

	
 
        return download_l
 

	
 
    def __get_readme_data(self, db_repo):
 
        repo_name = db_repo.repo_name
 
        log.debug('Looking for README file')
 

	
 
        @cache_region('long_term')
 
        def _get_readme_from_cache(key, kind):
 
            readme_data = None
 
            readme_file = None
 
            try:
 
                # gets the landing revision! or tip if fails
 
                cs = db_repo.get_landing_changeset()
 
                if isinstance(cs, EmptyChangeset):
 
                    raise EmptyRepositoryError()
 
                renderer = MarkupRenderer()
 
                for f in README_FILES:
 
                    try:
 
                        readme = cs.get_node(f)
 
                        if not isinstance(readme, FileNode):
 
                            continue
 
                        readme_file = f
 
                        log.debug('Found README file `%s` rendering...' %
 
                                  readme_file)
 
                        readme_data = renderer.render(readme.content,
 
                                                      filename=f)
 
                        break
 
                    except NodeDoesNotExistError:
 
                        continue
 
            except ChangesetError:
 
                log.error(traceback.format_exc())
 
                pass
 
            except EmptyRepositoryError:
 
                pass
 
            except Exception:
 
                log.error(traceback.format_exc())
 

	
 
            return readme_data, readme_file
 

	
 
        kind = 'README'
 
        valid = CacheInvalidation.test_and_set_valid(repo_name, kind)
 
        if not valid:
 
            region_invalidate(_get_readme_from_cache, None, repo_name, kind)
 
        return _get_readme_from_cache(repo_name, kind)
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    def index(self, repo_name):
 
        _load_changelog_summary()
 

	
 
        username = ''
 
        if self.authuser.username != User.DEFAULT_USER:
 
            username = safe_str(self.authuser.username)
 

	
 
        _def_clone_uri = _def_clone_uri_by_id = c.clone_uri_tmpl
 
        if '{repo}' in _def_clone_uri:
 
            _def_clone_uri_by_id = _def_clone_uri.replace('{repo}', '_{repoid}')
 
        elif '{repoid}' in _def_clone_uri:
 
            _def_clone_uri_by_id = _def_clone_uri.replace('_{repoid}', '{repo}')
 

	
 
        c.clone_repo_url = c.db_repo.clone_url(user=username,
 
                                                uri_tmpl=_def_clone_uri)
 
        c.clone_repo_url_id = c.db_repo.clone_url(user=username,
 
                                                uri_tmpl=_def_clone_uri_by_id)
 

	
 
        if c.db_repo.enable_statistics:
 
            c.show_stats = True
 
        else:
 
            c.show_stats = False
 

	
 
        stats = self.sa.query(Statistics)\
 
            .filter(Statistics.repository == c.db_repo)\
 
            .scalar()
 

	
 
        c.stats_percentage = 0
 

	
 
        if stats and stats.languages:
 
            c.no_data = False is c.db_repo.enable_statistics
 
            lang_stats_d = json.loads(stats.languages)
 

	
 
            lang_stats = ((x, {"count": y,
 
                               "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
 
                          for x, y in lang_stats_d.items())
 

	
 
            c.trending_languages = json.dumps(
 
                sorted(lang_stats, reverse=True, key=lambda k: k[1])[:10]
 
            )
 
        else:
 
            c.no_data = True
 
            c.trending_languages = json.dumps({})
 

	
 
        c.enable_downloads = c.db_repo.enable_downloads
 
        c.readme_data, c.readme_file = \
 
            self.__get_readme_data(c.db_repo)
 
        return render('summary/summary.html')
 

	
 
    @LoginRequired()
 
    @NotAnonymous()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    @jsonify
 
    def repo_size(self, repo_name):
 
        if request.is_xhr:
 
            return c.db_repo._repo_size()
 
        else:
 
            raise HTTPBadRequest()
 

	
 
    @LoginRequired()
 
    @HasRepoPermissionAnyDecorator('repository.read', 'repository.write',
 
                                   'repository.admin')
 
    def statistics(self, repo_name):
 
        if c.db_repo.enable_statistics:
 
            c.show_stats = True
 
            c.no_data_msg = _('No data loaded yet')
 
        else:
 
            c.show_stats = False
 
            c.no_data_msg = _('Statistics are disabled for this repository')
 

	
 
        td = date.today() + timedelta(days=1)
 
        td_1m = td - timedelta(days=calendar.mdays[td.month])
 
        td_1y = td - timedelta(days=365)
 

	
 
        ts_min_m = mktime(td_1m.timetuple())
 
        ts_min_y = mktime(td_1y.timetuple())
 
        ts_max_y = mktime(td.timetuple())
 
        c.ts_min = ts_min_m
 
        c.ts_max = ts_max_y
 

	
 
        stats = self.sa.query(Statistics)\
 
            .filter(Statistics.repository == c.db_repo)\
 
            .scalar()
 
        if stats and stats.languages:
 
            c.no_data = False is c.db_repo.enable_statistics
 
            lang_stats_d = json.loads(stats.languages)
 
            c.commit_data = stats.commit_activity
 
            c.overview_data = stats.commit_activity_combined
 

	
 
            lang_stats = ((x, {"count": y,
 
                               "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
 
                          for x, y in lang_stats_d.items())
 

	
 
            c.trending_languages = json.dumps(
 
                sorted(lang_stats, reverse=True, key=lambda k: k[1])[:10]
 
            )
 
            last_rev = stats.stat_on_revision + 1
 
            c.repo_last_rev = c.db_repo_scm_instance.count()\
 
                if c.db_repo_scm_instance.revisions else 0
 
            if last_rev == 0 or c.repo_last_rev == 0:
 
                pass
 
            else:
 
                c.stats_percentage = '%.2f' % ((float((last_rev)) /
 
                                                c.repo_last_rev) * 100)
 
        else:
 
            c.commit_data = json.dumps({})
 
            c.overview_data = json.dumps([[ts_min_y, 0], [ts_max_y, 10]])
 
            c.trending_languages = json.dumps({})
 
            c.no_data = True
 

	
 
        recurse_limit = 500  # don't recurse more than 500 times when parsing
 
        run_task(get_commits_stats, c.db_repo.repo_name, ts_min_y,
 
                 ts_max_y, recurse_limit)
 
        return render('summary/statistics.html')
kallithea/lib/auth.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.lib.auth
 
~~~~~~~~~~~~~~~~~~
 

	
 
authentication and permission libraries
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 4, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 
from __future__ import with_statement
 
import time
 
import random
 
import logging
 
import traceback
 
import hashlib
 
import itertools
 
import collections
 

	
 
from tempfile import _RandomNameSequence
 
from decorator import decorator
 

	
 
from pylons import url, request
 
from pylons.controllers.util import abort, redirect
 
from pylons.i18n.translation import _
 
from sqlalchemy import or_
 
from sqlalchemy.orm.exc import ObjectDeletedError
 
from sqlalchemy.orm import joinedload
 

	
 
from kallithea import __platform__, is_windows, is_unix
 
from kallithea.lib.vcs.utils.lazy import LazyProperty
 
from kallithea.model import meta
 
from kallithea.model.meta import Session
 
from kallithea.model.user import UserModel
 
from kallithea.model.db import User, Repository, Permission, \
 
    UserToPerm, UserGroupRepoToPerm, UserGroupToPerm, UserGroupMember, \
 
    RepoGroup, UserGroupRepoGroupToPerm, UserIpMap, UserGroupUserGroupToPerm, \
 
    UserGroup, UserApiKeys
 

	
 
from kallithea.lib.utils2 import safe_unicode, aslist
 
from kallithea.lib.utils import get_repo_slug, get_repo_group_slug, \
 
    get_user_group_slug, conditional_cache
 
from kallithea.lib.caching_query import FromCache
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class PasswordGenerator(object):
 
    """
 
    This is a simple class for generating password from different sets of
 
    characters
 
    usage::
 

	
 
        passwd_gen = PasswordGenerator()
 
        #print 8-letter password containing only big and small letters
 
            of alphabet
 
        passwd_gen.gen_password(8, passwd_gen.ALPHABETS_BIG_SMALL)
 
    """
 
    ALPHABETS_NUM = r'''1234567890'''
 
    ALPHABETS_SMALL = r'''qwertyuiopasdfghjklzxcvbnm'''
 
    ALPHABETS_BIG = r'''QWERTYUIOPASDFGHJKLZXCVBNM'''
 
    ALPHABETS_SPECIAL = r'''`-=[]\;',./~!@#$%^&*()_+{}|:"<>?'''
 
    ALPHABETS_FULL = ALPHABETS_BIG + ALPHABETS_SMALL \
 
        + ALPHABETS_NUM + ALPHABETS_SPECIAL
 
    ALPHABETS_ALPHANUM = ALPHABETS_BIG + ALPHABETS_SMALL + ALPHABETS_NUM
 
    ALPHABETS_BIG_SMALL = ALPHABETS_BIG + ALPHABETS_SMALL
 
    ALPHABETS_ALPHANUM_BIG = ALPHABETS_BIG + ALPHABETS_NUM
 
    ALPHABETS_ALPHANUM_SMALL = ALPHABETS_SMALL + ALPHABETS_NUM
 

	
 
    def __init__(self, passwd=''):
 
        self.passwd = passwd
 

	
 
    def gen_password(self, length, type_=None):
 
        if type_ is None:
 
            type_ = self.ALPHABETS_FULL
 
        self.passwd = ''.join([random.choice(type_) for _ in xrange(length)])
 
        return self.passwd
 

	
 

	
 
class KallitheaCrypto(object):
 

	
 
    @classmethod
 
    def hash_string(cls, str_):
 
        """
 
        Cryptographic function used for password hashing based on pybcrypt
 
        or pycrypto in windows
 

	
 
        :param password: password to hash
 
        """
 
        if is_windows:
 
            from hashlib import sha256
 
            return sha256(str_).hexdigest()
 
        elif is_unix:
 
            import bcrypt
 
            return bcrypt.hashpw(str_, bcrypt.gensalt(10))
 
        else:
 
            raise Exception('Unknown or unsupported platform %s' \
 
                            % __platform__)
 

	
 
    @classmethod
 
    def hash_check(cls, password, hashed):
 
        """
 
        Checks matching password with it's hashed value, runs different
 
        implementation based on platform it runs on
 

	
 
        :param password: password
 
        :param hashed: password in hashed form
 
        """
 

	
 
        if is_windows:
 
            from hashlib import sha256
 
            return sha256(password).hexdigest() == hashed
 
        elif is_unix:
 
            import bcrypt
 
            return bcrypt.hashpw(password, hashed) == hashed
 
        else:
 
            raise Exception('Unknown or unsupported platform %s' \
 
                            % __platform__)
 

	
 

	
 
def get_crypt_password(password):
 
    return KallitheaCrypto.hash_string(password)
 

	
 

	
 
def check_password(password, hashed):
 
    return KallitheaCrypto.hash_check(password, hashed)
 

	
 

	
 
def generate_api_key(str_, salt=None):
 
    """
 
    Generates API KEY from given string
 

	
 
    :param str_:
 
    :param salt:
 
    """
 

	
 
    if salt is None:
 
        salt = _RandomNameSequence().next()
 

	
 
    return hashlib.sha1(str_ + salt).hexdigest()
 

	
 

	
 
class CookieStoreWrapper(object):
 

	
 
    def __init__(self, cookie_store):
 
        self.cookie_store = cookie_store
 

	
 
    def __repr__(self):
 
        return 'CookieStore<%s>' % (self.cookie_store)
 

	
 
    def get(self, key, other=None):
 
        if isinstance(self.cookie_store, dict):
 
            return self.cookie_store.get(key, other)
 
        elif isinstance(self.cookie_store, AuthUser):
 
            return self.cookie_store.__dict__.get(key, other)
 

	
 

	
 

	
 
def _cached_perms_data(user_id, user_is_admin, user_inherit_default_permissions,
 
                       explicit, algo):
 
    RK = 'repositories'
 
    GK = 'repositories_groups'
 
    UK = 'user_groups'
 
    GLOBAL = 'global'
 
    PERM_WEIGHTS = Permission.PERM_WEIGHTS
 
    permissions = {RK: {}, GK: {}, UK: {}, GLOBAL: set()}
 

	
 
    def _choose_perm(new_perm, cur_perm):
 
        new_perm_val = PERM_WEIGHTS[new_perm]
 
        cur_perm_val = PERM_WEIGHTS[cur_perm]
 
        if algo == 'higherwin':
 
            if new_perm_val > cur_perm_val:
 
                return new_perm
 
            return cur_perm
 
        elif algo == 'lowerwin':
 
            if new_perm_val < cur_perm_val:
 
                return new_perm
 
            return cur_perm
 

	
 
    #======================================================================
 
    # fetch default permissions
 
    #======================================================================
 
    default_user = User.get_by_username('default', cache=True)
 
    default_user_id = default_user.user_id
 

	
 
    default_repo_perms = Permission.get_default_perms(default_user_id)
 
    default_repo_groups_perms = Permission.get_default_group_perms(default_user_id)
 
    default_user_group_perms = Permission.get_default_user_group_perms(default_user_id)
 

	
 
    if user_is_admin:
 
        #==================================================================
 
        # admin user have all default rights for repositories
 
        # and groups set to admin
 
        #==================================================================
 
        permissions[GLOBAL].add('hg.admin')
 
        permissions[GLOBAL].add('hg.create.write_on_repogroup.true')
 

	
 
        # repositories
 
        for perm in default_repo_perms:
 
            r_k = perm.UserRepoToPerm.repository.repo_name
 
            p = 'repository.admin'
 
            permissions[RK][r_k] = p
 

	
 
        # repository groups
 
        for perm in default_repo_groups_perms:
 
            rg_k = perm.UserRepoGroupToPerm.group.group_name
 
            p = 'group.admin'
 
            permissions[GK][rg_k] = p
 

	
 
        # user groups
 
        for perm in default_user_group_perms:
 
            u_k = perm.UserUserGroupToPerm.user_group.users_group_name
 
            p = 'usergroup.admin'
 
            permissions[UK][u_k] = p
 
        return permissions
 

	
 
    #==================================================================
 
    # SET DEFAULTS GLOBAL, REPOS, REPOSITORY GROUPS
 
    #==================================================================
 
    uid = user_id
 

	
 
    # default global permissions taken fron the default user
 
    default_global_perms = UserToPerm.query()\
 
        .filter(UserToPerm.user_id == default_user_id)\
 
        .options(joinedload(UserToPerm.permission))
 

	
 
    for perm in default_global_perms:
 
        permissions[GLOBAL].add(perm.permission.permission_name)
 

	
 
    # defaults for repositories, taken from default user
 
    for perm in default_repo_perms:
 
        r_k = perm.UserRepoToPerm.repository.repo_name
 
        if perm.Repository.private and not (perm.Repository.user_id == uid):
 
            # disable defaults for private repos,
 
            p = 'repository.none'
 
        elif perm.Repository.user_id == uid:
 
            # set admin if owner
 
            p = 'repository.admin'
 
        else:
 
            p = perm.Permission.permission_name
 

	
 
        permissions[RK][r_k] = p
 

	
 
    # defaults for repository groups taken from default user permission
 
    # on given group
 
    for perm in default_repo_groups_perms:
 
        rg_k = perm.UserRepoGroupToPerm.group.group_name
 
        p = perm.Permission.permission_name
 
        permissions[GK][rg_k] = p
 

	
 
    # defaults for user groups taken from default user permission
 
    # on given user group
 
    for perm in default_user_group_perms:
 
        u_k = perm.UserUserGroupToPerm.user_group.users_group_name
 
        p = perm.Permission.permission_name
 
        permissions[UK][u_k] = p
 

	
 
    #======================================================================
 
    # !! OVERRIDE GLOBALS !! with user permissions if any found
 
    #======================================================================
 
    # those can be configured from groups or users explicitly
 
    _configurable = set([
 
        'hg.fork.none', 'hg.fork.repository',
 
        'hg.create.none', 'hg.create.repository',
 
        'hg.usergroup.create.false', 'hg.usergroup.create.true'
 
    ])
 

	
 
    # USER GROUPS comes first
 
    # user group global permissions
 
    user_perms_from_users_groups = Session().query(UserGroupToPerm)\
 
        .options(joinedload(UserGroupToPerm.permission))\
 
        .join((UserGroupMember, UserGroupToPerm.users_group_id ==
 
               UserGroupMember.users_group_id))\
 
        .filter(UserGroupMember.user_id == uid)\
 
        .order_by(UserGroupToPerm.users_group_id)\
 
        .all()
 
    # need to group here by groups since user can be in more than
 
    # one group
 
    _grouped = [[x, list(y)] for x, y in
 
                itertools.groupby(user_perms_from_users_groups,
 
                                  lambda x:x.users_group)]
 
    for gr, perms in _grouped:
 
        # since user can be in multiple groups iterate over them and
 
        # select the lowest permissions first (more explicit)
 
        ##TODO: do this^^
 
        if not gr.inherit_default_permissions:
 
            # NEED TO IGNORE all configurable permissions and
 
            # replace them with explicitly set
 
            permissions[GLOBAL] = permissions[GLOBAL]\
 
                                            .difference(_configurable)
 
        for perm in perms:
 
            permissions[GLOBAL].add(perm.permission.permission_name)
 

	
 
    # user specific global permissions
 
    user_perms = Session().query(UserToPerm)\
 
            .options(joinedload(UserToPerm.permission))\
 
            .filter(UserToPerm.user_id == uid).all()
 

	
 
    if not user_inherit_default_permissions:
 
        # NEED TO IGNORE all configurable permissions and
 
        # replace them with explicitly set
 
        permissions[GLOBAL] = permissions[GLOBAL]\
 
                                        .difference(_configurable)
 

	
 
        for perm in user_perms:
 
            permissions[GLOBAL].add(perm.permission.permission_name)
 
    ## END GLOBAL PERMISSIONS
 

	
 
    #======================================================================
 
    # !! PERMISSIONS FOR REPOSITORIES !!
 
    #======================================================================
 
    #======================================================================
 
    # check if user is part of user groups for this repository and
 
    # fill in his permission from it. _choose_perm decides of which
 
    # permission should be selected based on selected method
 
    #======================================================================
 

	
 
    # user group for repositories permissions
 
    user_repo_perms_from_users_groups = \
 
     Session().query(UserGroupRepoToPerm, Permission, Repository,)\
 
        .join((Repository, UserGroupRepoToPerm.repository_id ==
 
               Repository.repo_id))\
 
        .join((Permission, UserGroupRepoToPerm.permission_id ==
 
               Permission.permission_id))\
 
        .join((UserGroupMember, UserGroupRepoToPerm.users_group_id ==
 
               UserGroupMember.users_group_id))\
 
        .filter(UserGroupMember.user_id == uid)\
 
        .all()
 

	
 
    multiple_counter = collections.defaultdict(int)
 
    for perm in user_repo_perms_from_users_groups:
 
        r_k = perm.UserGroupRepoToPerm.repository.repo_name
 
        multiple_counter[r_k] += 1
 
        p = perm.Permission.permission_name
 
        cur_perm = permissions[RK][r_k]
 

	
 
        if perm.Repository.user_id == uid:
 
            # set admin if owner
 
            p = 'repository.admin'
 
        else:
 
            if multiple_counter[r_k] > 1:
 
                p = _choose_perm(p, cur_perm)
 
        permissions[RK][r_k] = p
 

	
 
    # user explicit permissions for repositories, overrides any specified
 
    # by the group permission
 
    user_repo_perms = Permission.get_default_perms(uid)
 
    for perm in user_repo_perms:
 
        r_k = perm.UserRepoToPerm.repository.repo_name
 
        cur_perm = permissions[RK][r_k]
 
        # set admin if owner
 
        if perm.Repository.user_id == uid:
 
            p = 'repository.admin'
 
        else:
 
            p = perm.Permission.permission_name
 
            if not explicit:
 
                p = _choose_perm(p, cur_perm)
 
        permissions[RK][r_k] = p
 

	
 
    #======================================================================
 
    # !! PERMISSIONS FOR REPOSITORY GROUPS !!
 
    #======================================================================
 
    #======================================================================
 
    # check if user is part of user groups for this repository groups and
 
    # fill in his permission from it. _choose_perm decides of which
 
    # permission should be selected based on selected method
 
    #======================================================================
 
    # user group for repo groups permissions
 
    user_repo_group_perms_from_users_groups = \
 
     Session().query(UserGroupRepoGroupToPerm, Permission, RepoGroup)\
 
     .join((RepoGroup, UserGroupRepoGroupToPerm.group_id == RepoGroup.group_id))\
 
     .join((Permission, UserGroupRepoGroupToPerm.permission_id
 
            == Permission.permission_id))\
 
     .join((UserGroupMember, UserGroupRepoGroupToPerm.users_group_id
 
            == UserGroupMember.users_group_id))\
 
     .filter(UserGroupMember.user_id == uid)\
 
     .all()
 

	
 
    multiple_counter = collections.defaultdict(int)
 
    for perm in user_repo_group_perms_from_users_groups:
 
        g_k = perm.UserGroupRepoGroupToPerm.group.group_name
 
        multiple_counter[g_k] += 1
 
        p = perm.Permission.permission_name
 
        cur_perm = permissions[GK][g_k]
 
        if multiple_counter[g_k] > 1:
 
            p = _choose_perm(p, cur_perm)
 
        permissions[GK][g_k] = p
 

	
 
    # user explicit permissions for repository groups
 
    user_repo_groups_perms = Permission.get_default_group_perms(uid)
 
    for perm in user_repo_groups_perms:
 
        rg_k = perm.UserRepoGroupToPerm.group.group_name
 
        p = perm.Permission.permission_name
 
        cur_perm = permissions[GK][rg_k]
 
        if not explicit:
 
            p = _choose_perm(p, cur_perm)
 
        permissions[GK][rg_k] = p
 

	
 
    #======================================================================
 
    # !! PERMISSIONS FOR USER GROUPS !!
 
    #======================================================================
 
    # user group for user group permissions
 
    user_group_user_groups_perms = \
 
     Session().query(UserGroupUserGroupToPerm, Permission, UserGroup)\
 
     .join((UserGroup, UserGroupUserGroupToPerm.target_user_group_id
 
            == UserGroup.users_group_id))\
 
     .join((Permission, UserGroupUserGroupToPerm.permission_id
 
            == Permission.permission_id))\
 
     .join((UserGroupMember, UserGroupUserGroupToPerm.user_group_id
 
            == UserGroupMember.users_group_id))\
 
     .filter(UserGroupMember.user_id == uid)\
 
     .all()
 

	
 
    multiple_counter = collections.defaultdict(int)
 
    for perm in user_group_user_groups_perms:
 
        g_k = perm.UserGroupUserGroupToPerm.target_user_group.users_group_name
 
        multiple_counter[g_k] += 1
 
        p = perm.Permission.permission_name
 
        cur_perm = permissions[UK][g_k]
 
        if multiple_counter[g_k] > 1:
 
            p = _choose_perm(p, cur_perm)
 
        permissions[UK][g_k] = p
 

	
 
    #user explicit permission for user groups
 
    user_user_groups_perms = Permission.get_default_user_group_perms(uid)
 
    for perm in user_user_groups_perms:
 
        u_k = perm.UserUserGroupToPerm.user_group.users_group_name
 
        p = perm.Permission.permission_name
 
        cur_perm = permissions[UK][u_k]
 
        if not explicit:
 
            p = _choose_perm(p, cur_perm)
 
        permissions[UK][u_k] = p
 

	
 
    return permissions
 

	
 

	
 
def allowed_api_access(controller_name, whitelist=None, api_key=None):
 
    """
 
    Check if given controller_name is in whitelist API access
 
    """
 
    if not whitelist:
 
        from kallithea import CONFIG
 
        whitelist = aslist(CONFIG.get('api_access_controllers_whitelist'),
 
                           sep=',')
 
        log.debug('whitelist of API access is: %s' % (whitelist))
 
    api_access_valid = controller_name in whitelist
 
    if api_access_valid:
 
        log.debug('controller:%s is in API whitelist' % (controller_name))
 
    else:
 
        msg = 'controller: %s is *NOT* in API whitelist' % (controller_name)
 
        if api_key:
 
            #if we use API key and don't have access it's a warning
 
            log.warning(msg)
 
        else:
 
            log.debug(msg)
 
    return api_access_valid
 

	
 

	
 
class AuthUser(object):
 
    """
 
    A simple object that handles all attributes of user in Kallithea
 

	
 
    It does lookup based on API key,given user, or user present in session
 
    Then it fills all required information for such user. It also checks if
 
    anonymous access is enabled and if so, it returns default user as logged in
 
    """
 

	
 
    def __init__(self, user_id=None, api_key=None, username=None, ip_addr=None):
 

	
 
        self.user_id = user_id
 
        self._api_key = api_key
 

	
 
        self.api_key = None
 
        self.username = username
 
        self.ip_addr = ip_addr
 
        self.name = ''
 
        self.lastname = ''
 
        self.email = ''
 
        self.is_authenticated = False
 
        self.admin = False
 
        self.inherit_default_permissions = False
 

	
 
        self.propagate_data()
 
        self._instance = None
 

	
 
    @LazyProperty
 
    def permissions(self):
 
        return self.get_perms(user=self, cache=False)
 

	
 
    @property
 
    def api_keys(self):
 
        return self.get_api_keys()
 

	
 
    def propagate_data(self):
 
        user_model = UserModel()
 
        self.anonymous_user = User.get_default_user(cache=True)
 
        is_user_loaded = False
 

	
 
        # lookup by userid
 
        if self.user_id is not None and self.user_id != self.anonymous_user.user_id:
 
            log.debug('Auth User lookup by USER ID %s' % self.user_id)
 
            is_user_loaded = user_model.fill_data(self, user_id=self.user_id)
 

	
 
        # try go get user by api key
 
        elif self._api_key and self._api_key != self.anonymous_user.api_key:
 
            log.debug('Auth User lookup by API KEY %s' % self._api_key)
 
            is_user_loaded = user_model.fill_data(self, api_key=self._api_key)
 

	
 
        # lookup by username
 
        elif self.username:
 
            log.debug('Auth User lookup by USER NAME %s' % self.username)
 
            is_user_loaded = user_model.fill_data(self, username=self.username)
 
        else:
 
            log.debug('No data in %s that could been used to log in' % self)
 

	
 
        if not is_user_loaded:
 
            # if we cannot authenticate user try anonymous
 
            if self.anonymous_user.active:
 
                user_model.fill_data(self, user_id=self.anonymous_user.user_id)
 
                # then we set this user is logged in
 
                self.is_authenticated = True
 
            else:
 
                self.user_id = None
 
                self.username = None
 
                self.is_authenticated = False
 

	
 
        if not self.username:
 
            self.username = 'None'
 

	
 
        log.debug('Auth User is now %s' % self)
 

	
 
    def get_perms(self, user, explicit=True, algo='higherwin', cache=False):
 
        """
 
        Fills user permission attribute with permissions taken from database
 
        works for permissions given for repositories, and for permissions that
 
        are granted to groups
 

	
 
        :param user: instance of User object from database
 
        :param explicit: In case there are permissions both for user and a group
 
            that user is part of, explicit flag will defiine if user will
 
            explicitly override permissions from group, if it's False it will
 
            make decision based on the algo
 
        :param algo: algorithm to decide what permission should be choose if
 
            it's multiple defined, eg user in two different groups. It also
 
            decides if explicit flag is turned off how to specify the permission
 
            for case when user is in a group + have defined separate permission
 
        """
 
        user_id = user.user_id
 
        user_is_admin = user.is_admin
 
        user_inherit_default_permissions = user.inherit_default_permissions
 

	
 
        log.debug('Getting PERMISSION tree')
 
        compute = conditional_cache('short_term', 'cache_desc',
 
                                    condition=cache, func=_cached_perms_data)
 
        return compute(user_id, user_is_admin,
 
                       user_inherit_default_permissions, explicit, algo)
 

	
 
    def get_api_keys(self):
 
        api_keys = [self.api_key]
 
        for api_key in UserApiKeys.query()\
 
                .filter(UserApiKeys.user_id == self.user_id)\
 
                .filter(or_(UserApiKeys.expires == -1,
 
                            UserApiKeys.expires >= time.time())).all():
 
            api_keys.append(api_key.api_key)
 

	
 
        return api_keys
 

	
 
    @property
 
    def is_admin(self):
 
        return self.admin
 

	
 
    @property
 
    def repositories_admin(self):
 
        """
 
        Returns list of repositories you're an admin of
 
        """
 
        return [x[0] for x in self.permissions['repositories'].iteritems()
 
                if x[1] == 'repository.admin']
 

	
 
    @property
 
    def repository_groups_admin(self):
 
        """
 
        Returns list of repository groups you're an admin of
 
        """
 
        return [x[0] for x in self.permissions['repositories_groups'].iteritems()
 
                if x[1] == 'group.admin']
 

	
 
    @property
 
    def user_groups_admin(self):
 
        """
 
        Returns list of user groups you're an admin of
 
        """
 
        return [x[0] for x in self.permissions['user_groups'].iteritems()
 
                if x[1] == 'usergroup.admin']
 

	
 
    @property
 
    def ip_allowed(self):
 
        """
 
        Checks if ip_addr used in constructor is allowed from defined list of
 
        allowed ip_addresses for user
 

	
 
        :returns: boolean, True if ip is in allowed ip range
 
        """
 
        # check IP
 
        inherit = self.inherit_default_permissions
 
        return AuthUser.check_ip_allowed(self.user_id, self.ip_addr,
 
                                         inherit_from_default=inherit)
 

	
 
    @classmethod
 
    def check_ip_allowed(cls, user_id, ip_addr, inherit_from_default):
 
        allowed_ips = AuthUser.get_allowed_ips(user_id, cache=True,
 
                        inherit_from_default=inherit_from_default)
 
        if check_ip_access(source_ip=ip_addr, allowed_ips=allowed_ips):
 
            log.debug('IP:%s is in range of %s' % (ip_addr, allowed_ips))
 
            return True
 
        else:
 
            log.info('Access for IP:%s forbidden, '
 
                     'not in %s' % (ip_addr, allowed_ips))
 
            return False
 

	
 
    def __repr__(self):
 
        return "<AuthUser('id:%s[%s] ip:%s auth:%s')>"\
 
            % (self.user_id, self.username, self.ip_addr, self.is_authenticated)
 

	
 
    def set_authenticated(self, authenticated=True):
 
        if self.user_id != self.anonymous_user.user_id:
 
            self.is_authenticated = authenticated
 

	
 
    def get_cookie_store(self):
 
        return {'username': self.username,
 
                'user_id': self.user_id,
 
                'is_authenticated': self.is_authenticated}
 

	
 
    @classmethod
 
    def from_cookie_store(cls, cookie_store):
 
        """
 
        Creates AuthUser from a cookie store
 

	
 
        :param cls:
 
        :param cookie_store:
 
        """
 
        user_id = cookie_store.get('user_id')
 
        username = cookie_store.get('username')
 
        api_key = cookie_store.get('api_key')
 
        return AuthUser(user_id, api_key, username)
 

	
 
    @classmethod
 
    def get_allowed_ips(cls, user_id, cache=False, inherit_from_default=False):
 
        _set = set()
 

	
 
        if inherit_from_default:
 
            default_ips = UserIpMap.query().filter(UserIpMap.user ==
 
                                            User.get_default_user(cache=True))
 
            if cache:
 
                default_ips = default_ips.options(FromCache("sql_cache_short",
 
                                                  "get_user_ips_default"))
 

	
 
            # populate from default user
 
            for ip in default_ips:
 
                try:
 
                    _set.add(ip.ip_addr)
 
                except ObjectDeletedError:
 
                    # since we use heavy caching sometimes it happens that we get
 
                    # deleted objects here, we just skip them
 
                    pass
 

	
 
        user_ips = UserIpMap.query().filter(UserIpMap.user_id == user_id)
 
        if cache:
 
            user_ips = user_ips.options(FromCache("sql_cache_short",
 
                                                  "get_user_ips_%s" % user_id))
 

	
 
        for ip in user_ips:
 
            try:
 
                _set.add(ip.ip_addr)
 
            except ObjectDeletedError:
 
                # since we use heavy caching sometimes it happens that we get
 
                # deleted objects here, we just skip them
 
                pass
 
        return _set or set(['0.0.0.0/0', '::/0'])
 

	
 

	
 
def set_available_permissions(config):
 
    """
 
    This function will propagate pylons globals with all available defined
 
    permission given in db. We don't want to check each time from db for new
 
    permissions since adding a new permission also requires application restart
 
    ie. to decorate new views with the newly created permission
 

	
 
    :param config: current pylons config instance
 

	
 
    """
 
    log.info('getting information about all available permissions')
 
    try:
 
        sa = meta.Session
 
        all_perms = sa.query(Permission).all()
 
        config['available_permissions'] = [x.permission_name for x in all_perms]
 
    except Exception:
 
        log.error(traceback.format_exc())
 
    finally:
 
        meta.Session.remove()
 

	
 

	
 
#==============================================================================
 
# CHECK DECORATORS
 
#==============================================================================
 
class LoginRequired(object):
 
    """
 
    Must be logged in to execute this function else
 
    redirect to login page
 

	
 
    :param api_access: if enabled this checks only for valid auth token
 
        and grants access based on valid token
 
    """
 

	
 
    def __init__(self, api_access=False):
 
        self.api_access = api_access
 

	
 
    def __call__(self, func):
 
        return decorator(self.__wrapper, func)
 

	
 
    def __wrapper(self, func, *fargs, **fkwargs):
 
        cls = fargs[0]
 
        user = cls.authuser
 
        loc = "%s:%s" % (cls.__class__.__name__, func.__name__)
 

	
 
        # check if our IP is allowed
 
        ip_access_valid = True
 
        if not user.ip_allowed:
 
            from kallithea.lib import helpers as h
 
            h.flash(h.literal(_('IP %s not allowed' % (user.ip_addr))),
 
                    category='warning')
 
            ip_access_valid = False
 

	
 
        # check if we used an APIKEY and it's a valid one
 
        # defined whitelist of controllers which API access will be enabled
 
        _api_key = request.GET.get('api_key', '')
 
        api_access_valid = allowed_api_access(loc, api_key=_api_key)
 

	
 
        # explicit controller is enabled or API is in our whitelist
 
        if self.api_access or api_access_valid:
 
            log.debug('Checking API KEY access for %s' % cls)
 
            if _api_key and _api_key in user.api_keys:
 
                api_access_valid = True
 
                log.debug('API KEY ****%s is VALID' % _api_key[-4:])
 
            else:
 
                api_access_valid = False
 
                if not _api_key:
 
                    log.debug("API KEY *NOT* present in request")
 
                else:
 
                    log.warning("API KEY ****%s *NOT* valid" % _api_key[-4:])
 

	
 
        log.debug('Checking if %s is authenticated @ %s' % (user.username, loc))
 
        reason = 'RegularAuth' if user.is_authenticated else 'APIAuth'
 

	
 
        if ip_access_valid and (user.is_authenticated or api_access_valid):
 
            log.info('user %s authenticating with:%s IS authenticated on func %s '
 
                     % (user, reason, loc)
 
            )
 
            return func(*fargs, **fkwargs)
 
        else:
 
            log.warning('user %s authenticating with:%s NOT authenticated on func: %s: '
 
                     'IP_ACCESS:%s API_ACCESS:%s'
 
                     % (user, reason, loc, ip_access_valid, api_access_valid)
 
            )
 
            p = url.current()
 

	
 
            log.debug('redirecting to login page with %s' % p)
 
            return redirect(url('login_home', came_from=p))
 

	
 

	
 
class NotAnonymous(object):
 
    """
 
    Must be logged in to execute this function else
 
    redirect to login page"""
 

	
 
    def __call__(self, func):
 
        return decorator(self.__wrapper, func)
 

	
 
    def __wrapper(self, func, *fargs, **fkwargs):
 
        cls = fargs[0]
 
        self.user = cls.authuser
 

	
 
        log.debug('Checking if user is not anonymous @%s' % cls)
 

	
 
        anonymous = self.user.username == User.DEFAULT_USER
 

	
 
        if anonymous:
 
            p = url.current()
 

	
 
            import kallithea.lib.helpers as h
 
            h.flash(_('You need to be a registered user to '
 
                      'perform this action'),
 
                    category='warning')
 
            return redirect(url('login_home', came_from=p))
 
        else:
 
            return func(*fargs, **fkwargs)
 

	
 

	
 
class PermsDecorator(object):
 
    """Base class for controller decorators"""
 

	
 
    def __init__(self, *required_perms):
 
        self.required_perms = set(required_perms)
 
        self.user_perms = None
 

	
 
    def __call__(self, func):
 
        return decorator(self.__wrapper, func)
 

	
 
    def __wrapper(self, func, *fargs, **fkwargs):
 
        cls = fargs[0]
 
        self.user = cls.authuser
 
        self.user_perms = self.user.permissions
 
        log.debug('checking %s permissions %s for %s %s',
 
           self.__class__.__name__, self.required_perms, cls, self.user)
 

	
 
        if self.check_permissions():
 
            log.debug('Permission granted for %s %s' % (cls, self.user))
 
            return func(*fargs, **fkwargs)
 

	
 
        else:
 
            log.debug('Permission denied for %s %s' % (cls, self.user))
 
            anonymous = self.user.username == User.DEFAULT_USER
 

	
 
            if anonymous:
 
                p = url.current()
 

	
 
                import kallithea.lib.helpers as h
 
                h.flash(_('You need to be a signed in to '
 
                          'view this page'),
 
                        category='warning')
 
                return redirect(url('login_home', came_from=p))
 

	
 
            else:
 
                # redirect with forbidden ret code
 
                return abort(403)
 

	
 
    def check_permissions(self):
 
        """Dummy function for overriding"""
 
        raise Exception('You have to write this function in child class')
 

	
 

	
 
class HasPermissionAllDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for all given predicates. All of them
 
    have to be meet in order to fulfill the request
 
    """
 

	
 
    def check_permissions(self):
 
        if self.required_perms.issubset(self.user_perms.get('global')):
 
            return True
 
        return False
 

	
 

	
 
class HasPermissionAnyDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for any of given predicates. In order to
 
    fulfill the request any of predicates must be meet
 
    """
 

	
 
    def check_permissions(self):
 
        if self.required_perms.intersection(self.user_perms.get('global')):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoPermissionAllDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for all given predicates for specific
 
    repository. All of them have to be meet in order to fulfill the request
 
    """
 

	
 
    def check_permissions(self):
 
        repo_name = get_repo_slug(request)
 
        try:
 
            user_perms = set([self.user_perms['repositories'][repo_name]])
 
        except KeyError:
 
            return False
 
        if self.required_perms.issubset(user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoPermissionAnyDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for any of given predicates for specific
 
    repository. In order to fulfill the request any of predicates must be meet
 
    """
 

	
 
    def check_permissions(self):
 
        repo_name = get_repo_slug(request)
 
        try:
 
            user_perms = set([self.user_perms['repositories'][repo_name]])
 
        except KeyError:
 
            return False
 

	
 
        if self.required_perms.intersection(user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoGroupPermissionAllDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for all given predicates for specific
 
    repository group. All of them have to be meet in order to fulfill the request
 
    """
 

	
 
    def check_permissions(self):
 
        group_name = get_repo_group_slug(request)
 
        try:
 
            user_perms = set([self.user_perms['repositories_groups'][group_name]])
 
        except KeyError:
 
            return False
 

	
 
        if self.required_perms.issubset(user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoGroupPermissionAnyDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for any of given predicates for specific
 
    repository group. In order to fulfill the request any of predicates must be meet
 
    """
 

	
 
    def check_permissions(self):
 
        group_name = get_repo_group_slug(request)
 
        try:
 
            user_perms = set([self.user_perms['repositories_groups'][group_name]])
 
        except KeyError:
 
            return False
 

	
 
        if self.required_perms.intersection(user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasUserGroupPermissionAllDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for all given predicates for specific
 
    user group. All of them have to be meet in order to fulfill the request
 
    """
 

	
 
    def check_permissions(self):
 
        group_name = get_user_group_slug(request)
 
        try:
 
            user_perms = set([self.user_perms['user_groups'][group_name]])
 
        except KeyError:
 
            return False
 

	
 
        if self.required_perms.issubset(user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasUserGroupPermissionAnyDecorator(PermsDecorator):
 
    """
 
    Checks for access permission for any of given predicates for specific
 
    user group. In order to fulfill the request any of predicates must be meet
 
    """
 

	
 
    def check_permissions(self):
 
        group_name = get_user_group_slug(request)
 
        try:
 
            user_perms = set([self.user_perms['user_groups'][group_name]])
 
        except KeyError:
 
            return False
 

	
 
        if self.required_perms.intersection(user_perms):
 
            return True
 
        return False
 

	
 

	
 
#==============================================================================
 
# CHECK FUNCTIONS
 
#==============================================================================
 
class PermsFunction(object):
 
    """Base function for other check functions"""
 

	
 
    def __init__(self, *perms):
 
        self.required_perms = set(perms)
 
        self.user_perms = None
 
        self.repo_name = None
 
        self.group_name = None
 

	
 
    def __call__(self, check_location='', user=None):
 
        if not user:
 
            #TODO: remove this someday,put as user as attribute here
 
            user = request.user
 

	
 
        # init auth user if not already given
 
        if not isinstance(user, AuthUser):
 
            user = AuthUser(user.user_id)
 

	
 
        cls_name = self.__class__.__name__
 
        check_scope = {
 
            'HasPermissionAll': '',
 
            'HasPermissionAny': '',
 
            'HasRepoPermissionAll': 'repo:%s' % self.repo_name,
 
            'HasRepoPermissionAny': 'repo:%s' % self.repo_name,
 
            'HasRepoGroupPermissionAll': 'group:%s' % self.group_name,
 
            'HasRepoGroupPermissionAny': 'group:%s' % self.group_name,
 
        }.get(cls_name, '?')
 
        log.debug('checking cls:%s %s usr:%s %s @ %s', cls_name,
 
                  self.required_perms, user, check_scope,
 
                  check_location or 'unspecified location')
 
        if not user:
 
            log.debug('Empty request user')
 
            return False
 
        self.user_perms = user.permissions
 
        if self.check_permissions():
 
            log.debug('Permission to %s granted for user: %s @ %s'
 
                      % (check_scope, user,
 
                         check_location or 'unspecified location'))
 
            return True
 

	
 
        else:
 
            log.debug('Permission to %s denied for user: %s @ %s'
 
                      % (check_scope, user,
 
                         check_location or 'unspecified location'))
 
            return False
 

	
 
    def check_permissions(self):
 
        """Dummy function for overriding"""
 
        raise Exception('You have to write this function in child class')
 

	
 

	
 
class HasPermissionAll(PermsFunction):
 
    def check_permissions(self):
 
        if self.required_perms.issubset(self.user_perms.get('global')):
 
            return True
 
        return False
 

	
 

	
 
class HasPermissionAny(PermsFunction):
 
    def check_permissions(self):
 
        if self.required_perms.intersection(self.user_perms.get('global')):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoPermissionAll(PermsFunction):
 
    def __call__(self, repo_name=None, check_location='', user=None):
 
        self.repo_name = repo_name
 
        return super(HasRepoPermissionAll, self).__call__(check_location, user)
 

	
 
    def check_permissions(self):
 
        if not self.repo_name:
 
            self.repo_name = get_repo_slug(request)
 

	
 
        try:
 
            self._user_perms = set(
 
                [self.user_perms['repositories'][self.repo_name]]
 
            )
 
        except KeyError:
 
            return False
 
        if self.required_perms.issubset(self._user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoPermissionAny(PermsFunction):
 
    def __call__(self, repo_name=None, check_location='', user=None):
 
        self.repo_name = repo_name
 
        return super(HasRepoPermissionAny, self).__call__(check_location, user)
 

	
 
    def check_permissions(self):
 
        if not self.repo_name:
 
            self.repo_name = get_repo_slug(request)
 

	
 
        try:
 
            self._user_perms = set(
 
                [self.user_perms['repositories'][self.repo_name]]
 
            )
 
        except KeyError:
 
            return False
 
        if self.required_perms.intersection(self._user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoGroupPermissionAny(PermsFunction):
 
    def __call__(self, group_name=None, check_location='', user=None):
 
        self.group_name = group_name
 
        return super(HasRepoGroupPermissionAny, self).__call__(check_location, user)
 

	
 
    def check_permissions(self):
 
        try:
 
            self._user_perms = set(
 
                [self.user_perms['repositories_groups'][self.group_name]]
 
            )
 
        except KeyError:
 
            return False
 
        if self.required_perms.intersection(self._user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoGroupPermissionAll(PermsFunction):
 
    def __call__(self, group_name=None, check_location='', user=None):
 
        self.group_name = group_name
 
        return super(HasRepoGroupPermissionAll, self).__call__(check_location, user)
 

	
 
    def check_permissions(self):
 
        try:
 
            self._user_perms = set(
 
                [self.user_perms['repositories_groups'][self.group_name]]
 
            )
 
        except KeyError:
 
            return False
 
        if self.required_perms.issubset(self._user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasUserGroupPermissionAny(PermsFunction):
 
    def __call__(self, user_group_name=None, check_location='', user=None):
 
        self.user_group_name = user_group_name
 
        return super(HasUserGroupPermissionAny, self).__call__(check_location, user)
 

	
 
    def check_permissions(self):
 
        try:
 
            self._user_perms = set(
 
                [self.user_perms['user_groups'][self.user_group_name]]
 
            )
 
        except KeyError:
 
            return False
 
        if self.required_perms.intersection(self._user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasUserGroupPermissionAll(PermsFunction):
 
    def __call__(self, user_group_name=None, check_location='', user=None):
 
        self.user_group_name = user_group_name
 
        return super(HasUserGroupPermissionAll, self).__call__(check_location, user)
 

	
 
    def check_permissions(self):
 
        try:
 
            self._user_perms = set(
 
                [self.user_perms['user_groups'][self.user_group_name]]
 
            )
 
        except KeyError:
 
            return False
 
        if self.required_perms.issubset(self._user_perms):
 
            return True
 
        return False
 

	
 

	
 
#==============================================================================
 
# SPECIAL VERSION TO HANDLE MIDDLEWARE AUTH
 
#==============================================================================
 
class HasPermissionAnyMiddleware(object):
 
    def __init__(self, *perms):
 
        self.required_perms = set(perms)
 

	
 
    def __call__(self, user, repo_name):
 
        # repo_name MUST be unicode, since we handle keys in permission
 
        # dict by unicode
 
        repo_name = safe_unicode(repo_name)
 
        usr = AuthUser(user.user_id)
 
        try:
 
            self.user_perms = set([usr.permissions['repositories'][repo_name]])
 
        except Exception:
 
            log.error('Exception while accessing permissions %s' %
 
                      traceback.format_exc())
 
            self.user_perms = set()
 
        self.user_perms = set([usr.permissions['repositories'][repo_name]])
 
        self.username = user.username
 
        self.repo_name = repo_name
 
        return self.check_permissions()
 

	
 
    def check_permissions(self):
 
        log.debug('checking VCS protocol '
 
                  'permissions %s for user:%s repository:%s', self.user_perms,
 
                                                self.username, self.repo_name)
 
        if self.required_perms.intersection(self.user_perms):
 
            log.debug('Permission to repo: %s granted for user: %s @ %s'
 
                      % (self.repo_name, self.username, 'PermissionMiddleware'))
 
            return True
 
        log.debug('Permission to repo: %s denied for user: %s @ %s'
 
                  % (self.repo_name, self.username, 'PermissionMiddleware'))
 
        return False
 

	
 

	
 
#==============================================================================
 
# SPECIAL VERSION TO HANDLE API AUTH
 
#==============================================================================
 
class _BaseApiPerm(object):
 
    def __init__(self, *perms):
 
        self.required_perms = set(perms)
 

	
 
    def __call__(self, check_location=None, user=None, repo_name=None,
 
                 group_name=None):
 
        cls_name = self.__class__.__name__
 
        check_scope = 'user:%s' % (user)
 
        if repo_name:
 
            check_scope += ', repo:%s' % (repo_name)
 

	
 
        if group_name:
 
            check_scope += ', repo group:%s' % (group_name)
 

	
 
        log.debug('checking cls:%s %s %s @ %s'
 
                  % (cls_name, self.required_perms, check_scope, check_location))
 
        if not user:
 
            log.debug('Empty User passed into arguments')
 
            return False
 

	
 
        ## process user
 
        if not isinstance(user, AuthUser):
 
            user = AuthUser(user.user_id)
 
        if not check_location:
 
            check_location = 'unspecified'
 
        if self.check_permissions(user.permissions, repo_name, group_name):
 
            log.debug('Permission to %s granted for user: %s @ %s'
 
                      % (check_scope, user, check_location))
 
            return True
 

	
 
        else:
 
            log.debug('Permission to %s denied for user: %s @ %s'
 
                      % (check_scope, user, check_location))
 
            return False
 

	
 
    def check_permissions(self, perm_defs, repo_name=None, group_name=None):
 
        """
 
        implement in child class should return True if permissions are ok,
 
        False otherwise
 

	
 
        :param perm_defs: dict with permission definitions
 
        :param repo_name: repo name
 
        """
 
        raise NotImplementedError()
 

	
 

	
 
class HasPermissionAllApi(_BaseApiPerm):
 
    def check_permissions(self, perm_defs, repo_name=None, group_name=None):
 
        if self.required_perms.issubset(perm_defs.get('global')):
 
            return True
 
        return False
 

	
 

	
 
class HasPermissionAnyApi(_BaseApiPerm):
 
    def check_permissions(self, perm_defs, repo_name=None, group_name=None):
 
        if self.required_perms.intersection(perm_defs.get('global')):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoPermissionAllApi(_BaseApiPerm):
 
    def check_permissions(self, perm_defs, repo_name=None, group_name=None):
 
        try:
 
            _user_perms = set([perm_defs['repositories'][repo_name]])
 
        except KeyError:
 
            log.warning(traceback.format_exc())
 
            return False
 
        if self.required_perms.issubset(_user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoPermissionAnyApi(_BaseApiPerm):
 
    def check_permissions(self, perm_defs, repo_name=None, group_name=None):
 
        try:
 
            _user_perms = set([perm_defs['repositories'][repo_name]])
 
        except KeyError:
 
            log.warning(traceback.format_exc())
 
            return False
 
        if self.required_perms.intersection(_user_perms):
 
            return True
 
        return False
 

	
 

	
 
class HasRepoGroupPermissionAnyApi(_BaseApiPerm):
 
    def check_permissions(self, perm_defs, repo_name=None, group_name=None):
 
        try:
 
            _user_perms = set([perm_defs['repositories_groups'][group_name]])
 
        except KeyError:
 
            log.warning(traceback.format_exc())
 
            return False
 
        if self.required_perms.intersection(_user_perms):
 
            return True
 
        return False
 

	
 
class HasRepoGroupPermissionAllApi(_BaseApiPerm):
 
    def check_permissions(self, perm_defs, repo_name=None, group_name=None):
 
        try:
 
            _user_perms = set([perm_defs['repositories_groups'][group_name]])
 
        except KeyError:
 
            log.warning(traceback.format_exc())
 
            return False
 
        if self.required_perms.issubset(_user_perms):
 
            return True
 
        return False
 

	
 
def check_ip_access(source_ip, allowed_ips=None):
 
    """
 
    Checks if source_ip is a subnet of any of allowed_ips.
 

	
 
    :param source_ip:
 
    :param allowed_ips: list of allowed ips together with mask
 
    """
 
    from kallithea.lib import ipaddr
 
    log.debug('checking if ip:%s is subnet of %s' % (source_ip, allowed_ips))
 
    if isinstance(allowed_ips, (tuple, list, set)):
 
        for ip in allowed_ips:
 
            try:
 
                if ipaddr.IPAddress(source_ip) in ipaddr.IPNetwork(ip):
 
                    log.debug('IP %s is network %s' %
 
                              (ipaddr.IPAddress(source_ip), ipaddr.IPNetwork(ip)))
 
                    return True
 
                # for any case we cannot determine the IP, don't crash just
 
                # skip it and log as error, we want to say forbidden still when
 
                # sending bad IP
 
            except Exception:
 
                log.error(traceback.format_exc())
 
                continue
 
            if ipaddr.IPAddress(source_ip) in ipaddr.IPNetwork(ip):
 
                log.debug('IP %s is network %s' %
 
                          (ipaddr.IPAddress(source_ip), ipaddr.IPNetwork(ip)))
 
                return True
 
    return False
kallithea/lib/auth_modules/__init__.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
Authentication modules
 
"""
 

	
 
import logging
 
import traceback
 

	
 
from kallithea import EXTERN_TYPE_INTERNAL
 
from kallithea.lib.compat import importlib
 
from kallithea.lib.utils2 import str2bool
 
from kallithea.lib.compat import formatted_json, hybrid_property
 
from kallithea.lib.auth import PasswordGenerator
 
from kallithea.model.user import UserModel
 
from kallithea.model.db import Setting, User
 
from kallithea.model.meta import Session
 
from kallithea.model.user_group import UserGroupModel
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class LazyFormencode(object):
 
    def __init__(self, formencode_obj, *args, **kwargs):
 
        self.formencode_obj = formencode_obj
 
        self.args = args
 
        self.kwargs = kwargs
 

	
 
    def __call__(self, *args, **kwargs):
 
        from inspect import isfunction
 
        formencode_obj = self.formencode_obj
 
        if isfunction(formencode_obj):
 
            #case we wrap validators into functions
 
            formencode_obj = self.formencode_obj(*args, **kwargs)
 
        return formencode_obj(*self.args, **self.kwargs)
 

	
 

	
 
class KallitheaAuthPluginBase(object):
 
    auth_func_attrs = {
 
        "username": "unique username",
 
        "firstname": "first name",
 
        "lastname": "last name",
 
        "email": "email address",
 
        "groups": '["list", "of", "groups"]',
 
        "extern_name": "name in external source of record",
 
        "extern_type": "type of external source of record",
 
        "admin": 'True|False defines if user should be Kallithea super admin',
 
        "active": 'True|False defines active state of user internally for Kallithea',
 
        "active_from_extern": "True|False\None, active state from the external auth, "
 
                              "None means use definition from Kallithea extern_type active value"
 
    }
 

	
 
    @property
 
    def validators(self):
 
        """
 
        Exposes Kallithea validators modules
 
        """
 
        # this is a hack to overcome issues with pylons threadlocals and
 
        # translator object _() not beein registered properly.
 
        class LazyCaller(object):
 
            def __init__(self, name):
 
                self.validator_name = name
 

	
 
            def __call__(self, *args, **kwargs):
 
                from kallithea.model import validators as v
 
                obj = getattr(v, self.validator_name)
 
                #log.debug('Initializing lazy formencode object: %s' % obj)
 
                return LazyFormencode(obj, *args, **kwargs)
 

	
 

	
 
        class ProxyGet(object):
 
            def __getattribute__(self, name):
 
                return LazyCaller(name)
 

	
 
        return ProxyGet()
 

	
 
    @hybrid_property
 
    def name(self):
 
        """
 
        Returns the name of this authentication plugin.
 

	
 
        :returns: string
 
        """
 
        raise NotImplementedError("Not implemented in base class")
 

	
 
    @hybrid_property
 
    def is_container_auth(self):
 
        """
 
        Returns bool if this module uses container auth.
 

	
 
        This property will trigger an automatic call to authenticate on
 
        a visit to the website or during a push/pull.
 

	
 
        :returns: bool
 
        """
 
        return False
 

	
 
    def accepts(self, user, accepts_empty=True):
 
        """
 
        Checks if this authentication module should accept a request for
 
        the current user.
 

	
 
        :param user: user object fetched using plugin's get_user() method.
 
        :param accepts_empty: if True accepts don't allow the user to be empty
 
        :returns: boolean
 
        """
 
        plugin_name = self.name
 
        if not user and not accepts_empty:
 
            log.debug('User is empty not allowed to authenticate')
 
            return False
 

	
 
        if user and user.extern_type and user.extern_type != plugin_name:
 
            log.debug('User %s should authenticate using %s this is %s, skipping'
 
                      % (user, user.extern_type, plugin_name))
 

	
 
            return False
 
        return True
 

	
 
    def get_user(self, username=None, **kwargs):
 
        """
 
        Helper method for user fetching in plugins, by default it's using
 
        simple fetch by username, but this method can be custimized in plugins
 
        eg. container auth plugin to fetch user by environ params
 

	
 
        :param username: username if given to fetch from database
 
        :param kwargs: extra arguments needed for user fetching.
 
        """
 
        user = None
 
        log.debug('Trying to fetch user `%s` from Kallithea database'
 
                  % (username))
 
        if username:
 
            user = User.get_by_username(username)
 
            if not user:
 
                log.debug('Fallback to fetch user in case insensitive mode')
 
                user = User.get_by_username(username, case_insensitive=True)
 
        else:
 
            log.debug('provided username:`%s` is empty skipping...' % username)
 
        return user
 

	
 
    def settings(self):
 
        """
 
        Return a list of the form:
 
        [
 
            {
 
                "name": "OPTION_NAME",
 
                "type": "[bool|password|string|int|select]",
 
                ["values": ["opt1", "opt2", ...]]
 
                "validator": "expr"
 
                "description": "A short description of the option" [,
 
                "default": Default Value],
 
                ["formname": "Friendly Name for Forms"]
 
            } [, ...]
 
        ]
 

	
 
        This is used to interrogate the authentication plugin as to what
 
        settings it expects to be present and configured.
 

	
 
        'type' is a shorthand notation for what kind of value this option is.
 
        This is primarily used by the auth web form to control how the option
 
        is configured.
 
                bool : checkbox
 
                password : password input box
 
                string : input box
 
                select : single select dropdown
 

	
 
        'validator' is an lazy instantiated form field validator object, ala
 
        formencode. You need to *call* this object to init the validators.
 
        All calls to Kallithea validators should be used through self.validators
 
        which is a lazy loading proxy of formencode module.
 
        """
 
        raise NotImplementedError("Not implemented in base class")
 

	
 
    def plugin_settings(self):
 
        """
 
        This method is called by the authentication framework, not the .settings()
 
        method. This method adds a few default settings (e.g., "active"), so that
 
        plugin authors don't have to maintain a bunch of boilerplate.
 

	
 
        OVERRIDING THIS METHOD WILL CAUSE YOUR PLUGIN TO FAIL.
 
        """
 

	
 
        rcsettings = self.settings()
 
        rcsettings.insert(0, {
 
            "name": "enabled",
 
            "validator": self.validators.StringBoolean(if_missing=False),
 
            "type": "bool",
 
            "description": "Enable or Disable this Authentication Plugin",
 
            "formname": "Enabled"
 
            }
 
        )
 
        return rcsettings
 

	
 
    def user_activation_state(self):
 
        """
 
        Defines user activation state when creating new users
 

	
 
        :returns: boolean
 
        """
 
        raise NotImplementedError("Not implemented in base class")
 

	
 
    def auth(self, userobj, username, passwd, settings, **kwargs):
 
        """
 
        Given a user object (which may be null), username, a plaintext password,
 
        and a settings object (containing all the keys needed as listed in settings()),
 
        authenticate this user's login attempt.
 

	
 
        Return None on failure. On success, return a dictionary of the form:
 

	
 
            see: KallitheaAuthPluginBase.auth_func_attrs
 
        This is later validated for correctness
 
        """
 
        raise NotImplementedError("not implemented in base class")
 

	
 
    def _authenticate(self, userobj, username, passwd, settings, **kwargs):
 
        """
 
        Wrapper to call self.auth() that validates call on it
 

	
 
        :param userobj: userobj
 
        :param username: username
 
        :param passwd: plaintext password
 
        :param settings: plugin settings
 
        """
 
        auth = self.auth(userobj, username, passwd, settings, **kwargs)
 
        if auth:
 
            return self._validate_auth_return(auth)
 
        return auth
 

	
 
    def _validate_auth_return(self, ret):
 
        if not isinstance(ret, dict):
 
            raise Exception('returned value from auth must be a dict')
 
        for k in self.auth_func_attrs:
 
            if k not in ret:
 
                raise Exception('Missing %s attribute from returned data' % k)
 
        return ret
 

	
 

	
 
class KallitheaExternalAuthPlugin(KallitheaAuthPluginBase):
 
    def use_fake_password(self):
 
        """
 
        Return a boolean that indicates whether or not we should set the user's
 
        password to a random value when it is authenticated by this plugin.
 
        If your plugin provides authentication, then you will generally want this.
 

	
 
        :returns: boolean
 
        """
 
        raise NotImplementedError("Not implemented in base class")
 

	
 
    def _authenticate(self, userobj, username, passwd, settings, **kwargs):
 
        auth = super(KallitheaExternalAuthPlugin, self)._authenticate(
 
            userobj, username, passwd, settings, **kwargs)
 
        if auth:
 
            # maybe plugin will clean the username ?
 
            # we should use the return value
 
            username = auth['username']
 
            # if user is not active from our extern type we should fail to authe
 
            # this can prevent from creating users in Kallithea when using
 
            # external authentication, but if it's inactive user we shouldn't
 
            # create that user anyway
 
            if auth['active_from_extern'] is False:
 
                log.warning("User %s authenticated against %s, but is inactive"
 
                            % (username, self.__module__))
 
                return None
 

	
 
            if self.use_fake_password():
 
                # Randomize the PW because we don't need it, but don't want
 
                # them blank either
 
                passwd = PasswordGenerator().gen_password(length=8)
 

	
 
            log.debug('Updating or creating user info from %s plugin'
 
                      % self.name)
 
            user = UserModel().create_or_update(
 
                username=username,
 
                password=passwd,
 
                email=auth["email"],
 
                firstname=auth["firstname"],
 
                lastname=auth["lastname"],
 
                active=auth["active"],
 
                admin=auth["admin"],
 
                extern_name=auth["extern_name"],
 
                extern_type=self.name
 
            )
 
            Session().flush()
 
            # enforce user is just in given groups, all of them has to be ones
 
            # created from plugins. We store this info in _group_data JSON field
 
            try:
 
                groups = auth['groups'] or []
 
                UserGroupModel().enforce_groups(user, groups, self.name)
 
            except Exception:
 
                # for any reason group syncing fails, we should proceed with login
 
                log.error(traceback.format_exc())
 
            groups = auth['groups'] or []
 
            UserGroupModel().enforce_groups(user, groups, self.name)
 
            Session().commit()
 
        return auth
 

	
 

	
 
def importplugin(plugin):
 
    """
 
    Imports and returns the authentication plugin in the module named by plugin
 
    (e.g., plugin='kallithea.lib.auth_modules.auth_internal'). Returns the
 
    KallitheaAuthPluginBase subclass on success, raises exceptions on failure.
 

	
 
    raises:
 
        AttributeError -- no KallitheaAuthPlugin class in the module
 
        TypeError -- if the KallitheaAuthPlugin is not a subclass of ours KallitheaAuthPluginBase
 
        ImportError -- if we couldn't import the plugin at all
 
    """
 
    log.debug("Importing %s" % plugin)
 
    if not plugin.startswith(u'kallithea.lib.auth_modules.auth_'):
 
        parts = plugin.split(u'.lib.auth_modules.auth_', 1)
 
        if len(parts) == 2:
 
            _module, pn = parts
 
            if pn == EXTERN_TYPE_INTERNAL:
 
                pn = "internal"
 
            plugin = u'kallithea.lib.auth_modules.auth_' + pn
 
    PLUGIN_CLASS_NAME = "KallitheaAuthPlugin"
 
    try:
 
        module = importlib.import_module(plugin)
 
    except (ImportError, TypeError):
 
        log.error(traceback.format_exc())
 
        # TODO: make this more error prone, if by some accident we screw up
 
        # the plugin name, the crash is preatty bad and hard to recover
 
        raise
 

	
 
    log.debug("Loaded auth plugin from %s (module:%s, file:%s)"
 
              % (plugin, module.__name__, module.__file__))
 

	
 
    pluginclass = getattr(module, PLUGIN_CLASS_NAME)
 
    if not issubclass(pluginclass, KallitheaAuthPluginBase):
 
        raise TypeError("Authentication class %s.KallitheaAuthPlugin is not "
 
                        "a subclass of %s" % (plugin, KallitheaAuthPluginBase))
 
    return pluginclass
 

	
 

	
 
def loadplugin(plugin):
 
    """
 
    Loads and returns an instantiated authentication plugin.
 

	
 
        see: importplugin
 
    """
 
    plugin = importplugin(plugin)()
 
    if plugin.plugin_settings.im_func != KallitheaAuthPluginBase.plugin_settings.im_func:
 
        raise TypeError("Authentication class %s.KallitheaAuthPluginBase "
 
                        "has overriden the plugin_settings method, which is "
 
                        "forbidden." % plugin)
 
    return plugin
 

	
 

	
 
def authenticate(username, password, environ=None):
 
    """
 
    Authentication function used for access control,
 
    It tries to authenticate based on enabled authentication modules.
 

	
 
    :param username: username can be empty for container auth
 
    :param password: password can be empty for container auth
 
    :param environ: environ headers passed for container auth
 
    :returns: None if auth failed, plugin_user dict if auth is correct
 
    """
 

	
 
    auth_plugins = Setting.get_auth_plugins()
 
    log.debug('Authentication against %s plugins' % (auth_plugins,))
 
    for module in auth_plugins:
 
        try:
 
            plugin = loadplugin(module)
 
        except (ImportError, AttributeError, TypeError), e:
 
            raise ImportError('Failed to load authentication module %s : %s'
 
                              % (module, str(e)))
 
        log.debug('Trying authentication using ** %s **' % (module,))
 
        # load plugin settings from Kallithea database
 
        plugin_name = plugin.name
 
        plugin_settings = {}
 
        for v in plugin.plugin_settings():
 
            conf_key = "auth_%s_%s" % (plugin_name, v["name"])
 
            setting = Setting.get_by_name(conf_key)
 
            plugin_settings[v["name"]] = setting.app_settings_value if setting else None
 
        log.debug('Plugin settings \n%s' % formatted_json(plugin_settings))
 

	
 
        if not str2bool(plugin_settings["enabled"]):
 
            log.info("Authentication plugin %s is disabled, skipping for %s"
 
                     % (module, username))
 
            continue
 

	
 
        # use plugin's method of user extraction.
 
        user = plugin.get_user(username, environ=environ,
 
                               settings=plugin_settings)
 
        log.debug('Plugin %s extracted user is `%s`' % (module, user))
 
        if not plugin.accepts(user):
 
            log.debug('Plugin %s does not accept user `%s` for authentication'
 
                      % (module, user))
 
            continue
 
        else:
 
            log.debug('Plugin %s accepted user `%s` for authentication'
 
                      % (module, user))
 

	
 
        log.info('Authenticating user using %s plugin' % plugin.__module__)
 
        # _authenticate is a wrapper for .auth() method of plugin.
 
        # it checks if .auth() sends proper data. for KallitheaExternalAuthPlugin
 
        # it also maps users to Database and maps the attributes returned
 
        # from .auth() to Kallithea database. If this function returns data
 
        # then auth is correct.
 
        plugin_user = plugin._authenticate(user, username, password,
 
                                           plugin_settings,
 
                                           environ=environ or {})
 
        log.debug('PLUGIN USER DATA: %s' % plugin_user)
 

	
 
        if plugin_user:
 
            log.debug('Plugin returned proper authentication data')
 
            return plugin_user
 

	
 
        # we failed to Auth because .auth() method didn't return proper the user
 
        if username:
 
            log.warning("User `%s` failed to authenticate against %s"
 
                        % (username, plugin.__module__))
 
    return None
kallithea/lib/auth_modules/auth_pam.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.lib.auth_pam
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Kallithea authentication library for PAM
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Created on Apr 09, 2013
 
:author: Alexey Larikov
 
"""
 

	
 
import logging
 
import time
 
import pam
 
import pwd
 
import grp
 
import re
 
import socket
 
import threading
 

	
 
from kallithea.lib import auth_modules
 
from kallithea.lib.compat import formatted_json, hybrid_property
 

	
 
log = logging.getLogger(__name__)
 

	
 
# Cache to store PAM authenticated users
 
_auth_cache = dict()
 
_pam_lock = threading.Lock()
 

	
 

	
 
class KallitheaAuthPlugin(auth_modules.KallitheaExternalAuthPlugin):
 
    # PAM authnetication can be slow. Repository operations involve a lot of
 
    # auth calls. Little caching helps speedup push/pull operations significantly
 
    AUTH_CACHE_TTL = 4
 

	
 
    def __init__(self):
 
        global _auth_cache
 
        ts = time.time()
 
        cleared_cache = dict(
 
            [(k, v) for (k, v) in _auth_cache.items() if
 
             (v + KallitheaAuthPlugin.AUTH_CACHE_TTL > ts)])
 
        _auth_cache = cleared_cache
 

	
 
    @hybrid_property
 
    def name(self):
 
        return "pam"
 

	
 
    def settings(self):
 
        settings = [
 
            {
 
                "name": "service",
 
                "validator": self.validators.UnicodeString(strip=True),
 
                "type": "string",
 
                "description": "PAM service name to use for authentication",
 
                "default": "login",
 
                "formname": "PAM service name"
 
            },
 
            {
 
                "name": "gecos",
 
                "validator": self.validators.UnicodeString(strip=True),
 
                "type": "string",
 
                "description": "Regex for extracting user name/email etc "
 
                               "from Unix userinfo",
 
                "default": "(?P<last_name>.+),\s*(?P<first_name>\w+)",
 
                "formname": "Gecos Regex"
 
            }
 
        ]
 
        return settings
 

	
 
    def use_fake_password(self):
 
        return True
 

	
 
    def auth(self, userobj, username, password, settings, **kwargs):
 
        if username not in _auth_cache:
 
            # Need lock here, as PAM authentication is not thread safe
 
            _pam_lock.acquire()
 
            try:
 
                auth_result = pam.authenticate(username, password,
 
                                               settings["service"])
 
                # cache result only if we properly authenticated
 
                if auth_result:
 
                    _auth_cache[username] = time.time()
 
            finally:
 
                _pam_lock.release()
 

	
 
            if not auth_result:
 
                log.error("PAM was unable to authenticate user: %s" % (username,))
 
                return None
 
        else:
 
            log.debug("Using cached auth for user: %s" % (username,))
 

	
 
        # old attrs fetched from Kallithea database
 
        admin = getattr(userobj, 'admin', False)
 
        active = getattr(userobj, 'active', True)
 
        email = getattr(userobj, 'email', '') or "%s@%s" % (username, socket.gethostname())
 
        firstname = getattr(userobj, 'firstname', '')
 
        lastname = getattr(userobj, 'lastname', '')
 
        extern_type = getattr(userobj, 'extern_type', '')
 

	
 
        user_attrs = {
 
            'username': username,
 
            'firstname': firstname,
 
            'lastname': lastname,
 
            'groups': [g.gr_name for g in grp.getgrall() if username in g.gr_mem],
 
            'email': email,
 
            'admin': admin,
 
            'active': active,
 
            "active_from_extern": None,
 
            'extern_name': username,
 
            'extern_type': extern_type,
 
        }
 

	
 
        try:
 
            user_data = pwd.getpwnam(username)
 
            regex = settings["gecos"]
 
            match = re.search(regex, user_data.pw_gecos)
 
            if match:
 
                user_attrs["firstname"] = match.group('first_name')
 
                user_attrs["lastname"] = match.group('last_name')
 
        except Exception:
 
            log.warning("Cannot extract additional info for PAM user")
 
            log.warning("Cannot extract additional info for PAM user %s", username)
 
            pass
 

	
 
        log.debug("pamuser: \n%s" % formatted_json(user_attrs))
 
        log.info('user %s authenticated correctly' % user_attrs['username'])
 
        return user_attrs
kallithea/lib/celerylib/tasks.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.lib.celerylib.tasks
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Kallithea task modules, containing all task that suppose to be run
 
by celery daemon
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Oct 6, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
from celery.decorators import task
 

	
 
import os
 
import traceback
 
import logging
 
from os.path import join as jn
 

	
 
from time import mktime
 
from operator import itemgetter
 
from string import lower
 

	
 
from pylons import config
 

	
 
from kallithea import CELERY_ON
 
from kallithea.lib.celerylib import run_task, locked_task, dbsession, \
 
    str2bool, __get_lockkey, LockHeld, DaemonLock, get_session
 
from kallithea.lib.helpers import person
 
from kallithea.lib.rcmail.smtp_mailer import SmtpMailer
 
from kallithea.lib.utils import add_cache, action_logger
 
from kallithea.lib.compat import json, OrderedDict
 
from kallithea.lib.hooks import log_create_repository
 

	
 
from kallithea.model.db import Statistics, Repository, User
 

	
 

	
 
add_cache(config)  # pragma: no cover
 

	
 
__all__ = ['whoosh_index', 'get_commits_stats', 'send_email']
 

	
 

	
 
def get_logger(cls):
 
    if CELERY_ON:
 
        try:
 
            log = cls.get_logger()
 
        except Exception:
 
            log = logging.getLogger(__name__)
 
    else:
 
        log = logging.getLogger(__name__)
 

	
 
    return log
 
            return cls.get_logger()
 
        except AttributeError:
 
            pass
 
    return logging.getLogger(__name__)
 

	
 

	
 
@task(ignore_result=True)
 
@locked_task
 
@dbsession
 
def whoosh_index(repo_location, full_index):
 
    from kallithea.lib.indexers.daemon import WhooshIndexingDaemon
 
    DBS = get_session()
 

	
 
    index_location = config['index_dir']
 
    WhooshIndexingDaemon(index_location=index_location,
 
                         repo_location=repo_location, sa=DBS)\
 
                         .run(full_index=full_index)
 

	
 

	
 
@task(ignore_result=True)
 
@dbsession
 
def get_commits_stats(repo_name, ts_min_y, ts_max_y, recurse_limit=100):
 
    log = get_logger(get_commits_stats)
 
    DBS = get_session()
 
    lockkey = __get_lockkey('get_commits_stats', repo_name, ts_min_y,
 
                            ts_max_y)
 
    lockkey_path = config['app_conf']['cache_dir']
 

	
 
    log.info('running task with lockkey %s' % lockkey)
 

	
 
    try:
 
        lock = l = DaemonLock(file_=jn(lockkey_path, lockkey))
 

	
 
        # for js data compatibility cleans the key for person from '
 
        akc = lambda k: person(k).replace('"', "")
 

	
 
        co_day_auth_aggr = {}
 
        commits_by_day_aggregate = {}
 
        repo = Repository.get_by_repo_name(repo_name)
 
        if repo is None:
 
            return True
 

	
 
        repo = repo.scm_instance
 
        repo_size = repo.count()
 
        # return if repo have no revisions
 
        if repo_size < 1:
 
            lock.release()
 
            return True
 

	
 
        skip_date_limit = True
 
        parse_limit = int(config['app_conf'].get('commit_parse_limit'))
 
        last_rev = None
 
        last_cs = None
 
        timegetter = itemgetter('time')
 

	
 
        dbrepo = DBS.query(Repository)\
 
            .filter(Repository.repo_name == repo_name).scalar()
 
        cur_stats = DBS.query(Statistics)\
 
            .filter(Statistics.repository == dbrepo).scalar()
 

	
 
        if cur_stats is not None:
 
            last_rev = cur_stats.stat_on_revision
 

	
 
        if last_rev == repo.get_changeset().revision and repo_size > 1:
 
            # pass silently without any work if we're not on first revision or
 
            # current state of parsing revision(from db marker) is the
 
            # last revision
 
            lock.release()
 
            return True
 

	
 
        if cur_stats:
 
            commits_by_day_aggregate = OrderedDict(json.loads(
 
                                        cur_stats.commit_activity_combined))
 
            co_day_auth_aggr = json.loads(cur_stats.commit_activity)
 

	
 
        log.debug('starting parsing %s' % parse_limit)
 
        lmktime = mktime
 

	
 
        last_rev = last_rev + 1 if last_rev >= 0 else 0
 
        log.debug('Getting revisions from %s to %s' % (
 
             last_rev, last_rev + parse_limit)
 
        )
 
        for cs in repo[last_rev:last_rev + parse_limit]:
 
            log.debug('parsing %s' % cs)
 
            last_cs = cs  # remember last parsed changeset
 
            k = lmktime([cs.date.timetuple()[0], cs.date.timetuple()[1],
 
                          cs.date.timetuple()[2], 0, 0, 0, 0, 0, 0])
 

	
 
            if akc(cs.author) in co_day_auth_aggr:
 
                try:
 
                    l = [timegetter(x) for x in
 
                         co_day_auth_aggr[akc(cs.author)]['data']]
 
                    time_pos = l.index(k)
 
                except ValueError:
 
                    time_pos = None
 

	
 
                if time_pos >= 0 and time_pos is not None:
 

	
 
                    datadict = \
 
                        co_day_auth_aggr[akc(cs.author)]['data'][time_pos]
 

	
 
                    datadict["commits"] += 1
 
                    datadict["added"] += len(cs.added)
 
                    datadict["changed"] += len(cs.changed)
 
                    datadict["removed"] += len(cs.removed)
 

	
 
                else:
 
                    if k >= ts_min_y and k <= ts_max_y or skip_date_limit:
 

	
 
                        datadict = {"time": k,
 
                                    "commits": 1,
 
                                    "added": len(cs.added),
 
                                    "changed": len(cs.changed),
 
                                    "removed": len(cs.removed),
 
                                   }
 
                        co_day_auth_aggr[akc(cs.author)]['data']\
 
                            .append(datadict)
 

	
 
            else:
 
                if k >= ts_min_y and k <= ts_max_y or skip_date_limit:
 
                    co_day_auth_aggr[akc(cs.author)] = {
 
                                        "label": akc(cs.author),
 
                                        "data": [{"time":k,
 
                                                 "commits":1,
 
                                                 "added":len(cs.added),
 
                                                 "changed":len(cs.changed),
 
                                                 "removed":len(cs.removed),
 
                                                 }],
 
                                        "schema": ["commits"],
 
                                        }
 

	
 
            #gather all data by day
 
            if k in commits_by_day_aggregate:
 
                commits_by_day_aggregate[k] += 1
 
            else:
 
                commits_by_day_aggregate[k] = 1
 

	
 
        overview_data = sorted(commits_by_day_aggregate.items(),
 
                               key=itemgetter(0))
 

	
 
        if not co_day_auth_aggr:
 
            co_day_auth_aggr[akc(repo.contact)] = {
 
                "label": akc(repo.contact),
 
                "data": [0, 1],
 
                "schema": ["commits"],
 
            }
 

	
 
        stats = cur_stats if cur_stats else Statistics()
 
        stats.commit_activity = json.dumps(co_day_auth_aggr)
 
        stats.commit_activity_combined = json.dumps(overview_data)
 

	
 
        log.debug('last revison %s' % last_rev)
 
        leftovers = len(repo.revisions[last_rev:])
 
        log.debug('revisions to parse %s' % leftovers)
 

	
 
        if last_rev == 0 or leftovers < parse_limit:
 
            log.debug('getting code trending stats')
 
            stats.languages = json.dumps(__get_codes_stats(repo_name))
 

	
 
        try:
 
            stats.repository = dbrepo
 
            stats.stat_on_revision = last_cs.revision if last_cs else 0
 
            DBS.add(stats)
 
            DBS.commit()
 
        except:
 
            log.error(traceback.format_exc())
 
            DBS.rollback()
 
            lock.release()
 
            return False
 

	
 
        # final release
 
        lock.release()
 

	
 
        # execute another task if celery is enabled
 
        if len(repo.revisions) > 1 and CELERY_ON and recurse_limit > 0:
 
            recurse_limit -= 1
 
            run_task(get_commits_stats, repo_name, ts_min_y, ts_max_y,
 
                     recurse_limit)
 
        if recurse_limit <= 0:
 
            log.debug('Breaking recursive mode due to reach of recurse limit')
 
        return True
 
    except LockHeld:
 
        log.info('LockHeld')
 
        return 'Task with key %s already running' % lockkey
 

	
 

	
 
@task(ignore_result=True)
 
@dbsession
 
def send_email(recipients, subject, body='', html_body='', headers=None):
 
    """
 
    Sends an email with defined parameters from the .ini files.
 

	
 
    :param recipients: list of recipients, if this is None, the defined email
 
        address from field 'email_to' and all admins is used instead
 
    :param subject: subject of the mail
 
    :param body: body of the mail
 
    :param html_body: html version of body
 
    """
 
    log = get_logger(send_email)
 
    assert isinstance(recipients, list), recipients
 

	
 
    email_config = config
 
    email_prefix = email_config.get('email_prefix', '')
 
    if email_prefix:
 
        subject = "%s %s" % (email_prefix, subject)
 
    if recipients is None:
 
        # if recipients are not defined we send to email_config + all admins
 
        admins = [u.email for u in User.query()
 
                  .filter(User.admin == True).all()]
 
        recipients = [email_config.get('email_to')] + admins
 
        log.warning("recipients not specified for '%s' - sending to admins %s", subject, ' '.join(recipients))
 
    elif not recipients:
 
        log.error("No recipients specified")
 
        return False
 

	
 
    mail_from = email_config.get('app_email_from', 'Kallithea')
 
    user = email_config.get('smtp_username')
 
    passwd = email_config.get('smtp_password')
 
    mail_server = email_config.get('smtp_server')
 
    mail_port = email_config.get('smtp_port')
 
    tls = str2bool(email_config.get('smtp_use_tls'))
 
    ssl = str2bool(email_config.get('smtp_use_ssl'))
 
    debug = str2bool(email_config.get('debug'))
 
    smtp_auth = email_config.get('smtp_auth')
 

	
 
    if not mail_server:
 
        log.error("SMTP mail server not configured - cannot send mail '%s' to %s", subject, ' '.join(recipients))
 
        log.warning("body:\n%s", body)
 
        log.warning("html:\n%s", html_body)
 
        return False
 

	
 
    try:
 
        m = SmtpMailer(mail_from, user, passwd, mail_server, smtp_auth,
 
                       mail_port, ssl, tls, debug=debug)
 
        m.send(recipients, subject, body, html_body, headers=headers)
 
    except:
 
        log.error('Mail sending failed')
 
        log.error(traceback.format_exc())
 
        return False
 
    return True
 

	
 
@task(ignore_result=False)
 
@dbsession
 
def create_repo(form_data, cur_user):
 
    from kallithea.model.repo import RepoModel
 
    from kallithea.model.user import UserModel
 
    from kallithea.model.db import Setting
 

	
 
    log = get_logger(create_repo)
 
    DBS = get_session()
 

	
 
    cur_user = UserModel(DBS)._get_user(cur_user)
 

	
 
    owner = cur_user
 
    repo_name = form_data['repo_name']
 
    repo_name_full = form_data['repo_name_full']
 
    repo_type = form_data['repo_type']
 
    description = form_data['repo_description']
 
    private = form_data['repo_private']
 
    clone_uri = form_data.get('clone_uri')
 
    repo_group = form_data['repo_group']
 
    landing_rev = form_data['repo_landing_rev']
 
    copy_fork_permissions = form_data.get('copy_permissions')
 
    copy_group_permissions = form_data.get('repo_copy_permissions')
 
    fork_of = form_data.get('fork_parent_id')
 
    state = form_data.get('repo_state', Repository.STATE_PENDING)
 

	
 
    # repo creation defaults, private and repo_type are filled in form
 
    defs = Setting.get_default_repo_settings(strip_prefix=True)
 
    enable_statistics = defs.get('repo_enable_statistics')
 
    enable_locking = defs.get('repo_enable_locking')
 
    enable_downloads = defs.get('repo_enable_downloads')
 

	
 
    try:
 
        repo = RepoModel(DBS)._create_repo(
 
            repo_name=repo_name_full,
 
            repo_type=repo_type,
 
            description=description,
 
            owner=owner,
 
            private=private,
 
            clone_uri=clone_uri,
 
            repo_group=repo_group,
 
            landing_rev=landing_rev,
 
            fork_of=fork_of,
 
            copy_fork_permissions=copy_fork_permissions,
 
            copy_group_permissions=copy_group_permissions,
 
            enable_statistics=enable_statistics,
 
            enable_locking=enable_locking,
 
            enable_downloads=enable_downloads,
 
            state=state
 
        )
 

	
 
        action_logger(cur_user, 'user_created_repo',
 
                      form_data['repo_name_full'], '', DBS)
 

	
 
        DBS.commit()
 
        # now create this repo on Filesystem
 
        RepoModel(DBS)._create_filesystem_repo(
 
            repo_name=repo_name,
 
            repo_type=repo_type,
 
            repo_group=RepoModel(DBS)._get_repo_group(repo_group),
 
            clone_uri=clone_uri,
 
        )
 
        repo = Repository.get_by_repo_name(repo_name_full)
 
        log_create_repository(repo.get_dict(), created_by=owner.username)
 

	
 
        # update repo changeset caches initially
 
        repo.update_changeset_cache()
 

	
 
        # set new created state
 
        repo.set_state(Repository.STATE_CREATED)
 
        DBS.commit()
 
    except Exception, e:
 
        log.warning('Exception %s occured when forking repository, '
 
                    'doing cleanup...' % e)
 
        # rollback things manually !
 
        repo = Repository.get_by_repo_name(repo_name_full)
 
        if repo:
 
            Repository.delete(repo.repo_id)
 
            DBS.commit()
 
            RepoModel(DBS)._delete_filesystem_repo(repo)
 
        raise
 

	
 
    # it's an odd fix to make celery fail task when exception occurs
 
    def on_failure(self, *args, **kwargs):
 
        pass
 

	
 
    return True
 

	
 

	
 
@task(ignore_result=False)
 
@dbsession
 
def create_repo_fork(form_data, cur_user):
 
    """
 
    Creates a fork of repository using interval VCS methods
 

	
 
    :param form_data:
 
    :param cur_user:
 
    """
 
    from kallithea.model.repo import RepoModel
 
    from kallithea.model.user import UserModel
 

	
 
    log = get_logger(create_repo_fork)
 
    DBS = get_session()
 

	
 
    base_path = Repository.base_path()
 
    cur_user = UserModel(DBS)._get_user(cur_user)
 

	
 
    repo_name = form_data['repo_name']  # fork in this case
 
    repo_name_full = form_data['repo_name_full']
 

	
 
    repo_type = form_data['repo_type']
 
    owner = cur_user
 
    private = form_data['private']
 
    clone_uri = form_data.get('clone_uri')
 
    repo_group = form_data['repo_group']
 
    landing_rev = form_data['landing_rev']
 
    copy_fork_permissions = form_data.get('copy_permissions')
 

	
 
    try:
 
        fork_of = RepoModel(DBS)._get_repo(form_data.get('fork_parent_id'))
 

	
 
        RepoModel(DBS)._create_repo(
 
            repo_name=repo_name_full,
 
            repo_type=repo_type,
 
            description=form_data['description'],
 
            owner=owner,
 
            private=private,
 
            clone_uri=clone_uri,
 
            repo_group=repo_group,
 
            landing_rev=landing_rev,
 
            fork_of=fork_of,
 
            copy_fork_permissions=copy_fork_permissions
 
        )
 
        action_logger(cur_user, 'user_forked_repo:%s' % repo_name_full,
 
                      fork_of.repo_name, '', DBS)
 
        DBS.commit()
 

	
 
        update_after_clone = form_data['update_after_clone'] # FIXME - unused!
 
        source_repo_path = os.path.join(base_path, fork_of.repo_name)
 

	
 
        # now create this repo on Filesystem
 
        RepoModel(DBS)._create_filesystem_repo(
 
            repo_name=repo_name,
 
            repo_type=repo_type,
 
            repo_group=RepoModel(DBS)._get_repo_group(repo_group),
 
            clone_uri=source_repo_path,
 
        )
 
        repo = Repository.get_by_repo_name(repo_name_full)
 
        log_create_repository(repo.get_dict(), created_by=owner.username)
 

	
 
        # update repo changeset caches initially
 
        repo.update_changeset_cache()
 

	
 
        # set new created state
 
        repo.set_state(Repository.STATE_CREATED)
 
        DBS.commit()
 
    except Exception, e:
 
        log.warning('Exception %s occured when forking repository, '
 
                    'doing cleanup...' % e)
 
        #rollback things manually !
 
        repo = Repository.get_by_repo_name(repo_name_full)
 
        if repo:
 
            Repository.delete(repo.repo_id)
 
            DBS.commit()
 
            RepoModel(DBS)._delete_filesystem_repo(repo)
 
        raise
 

	
 
    # it's an odd fix to make celery fail task when exception occurs
 
    def on_failure(self, *args, **kwargs):
 
        pass
 

	
 
    return True
 

	
 

	
 
def __get_codes_stats(repo_name):
 
    from kallithea.config.conf import LANGUAGES_EXTENSIONS_MAP
 
    repo = Repository.get_by_repo_name(repo_name).scm_instance
 

	
 
    tip = repo.get_changeset()
 
    code_stats = {}
 

	
 
    def aggregate(cs):
 
        for f in cs[2]:
 
            ext = lower(f.extension)
 
            if ext in LANGUAGES_EXTENSIONS_MAP.keys() and not f.is_binary:
 
                if ext in code_stats:
 
                    code_stats[ext] += 1
 
                else:
 
                    code_stats[ext] = 1
 

	
 
    map(aggregate, tip.walk('/'))
 

	
 
    return code_stats or {}
kallithea/lib/db_manage.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.lib.db_manage
 
~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Database creation, and setup module for Kallithea. Used for creation
 
of database as well as for migration operations
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 10, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
import os
 
import sys
 
import time
 
import uuid
 
import logging
 
from os.path import dirname as dn, join as jn
 

	
 
from kallithea import __dbversion__, __py_version__, EXTERN_TYPE_INTERNAL, DB_MIGRATIONS
 
from kallithea.model.user import UserModel
 
from kallithea.lib.utils import ask_ok
 
from kallithea.model import init_model
 
from kallithea.model.db import User, Permission, Ui, \
 
    Setting, UserToPerm, DbMigrateVersion, RepoGroup, \
 
    UserRepoGroupToPerm, CacheInvalidation, UserGroup, Repository
 

	
 
from sqlalchemy.engine import create_engine
 
from kallithea.model.repo_group import RepoGroupModel
 
#from kallithea.model import meta
 
from kallithea.model.meta import Session, Base
 
from kallithea.model.repo import RepoModel
 
from kallithea.model.permission import PermissionModel
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
def notify(msg):
 
    """
 
    Notification for migrations messages
 
    """
 
    ml = len(msg) + (4 * 2)
 
    print('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper()
 

	
 

	
 
class DbManage(object):
 
    def __init__(self, log_sql, dbconf, root, tests=False, SESSION=None, cli_args={}):
 
        self.dbname = dbconf.split('/')[-1]
 
        self.tests = tests
 
        self.root = root
 
        self.dburi = dbconf
 
        self.log_sql = log_sql
 
        self.db_exists = False
 
        self.cli_args = cli_args
 
        self.init_db(SESSION=SESSION)
 

	
 
        force_ask = self.cli_args.get('force_ask')
 
        if force_ask is not None:
 
            global ask_ok
 
            ask_ok = lambda *args, **kwargs: force_ask
 

	
 
    def init_db(self, SESSION=None):
 
        if SESSION:
 
            self.sa = SESSION
 
        else:
 
            #init new sessions
 
            engine = create_engine(self.dburi, echo=self.log_sql)
 
            init_model(engine)
 
            self.sa = Session()
 

	
 
    def create_tables(self, override=False):
 
        """
 
        Create a auth database
 
        """
 

	
 
        log.info("Any existing database is going to be destroyed")
 
        if self.tests:
 
            destroy = True
 
        else:
 
            destroy = ask_ok('Are you sure to destroy old database ? [y/n]')
 
        if not destroy:
 
            print 'Nothing done.'
 
            sys.exit(0)
 
        if destroy:
 
            Base.metadata.drop_all()
 

	
 
        checkfirst = not override
 
        Base.metadata.create_all(checkfirst=checkfirst)
 
        log.info('Created tables for %s' % self.dbname)
 

	
 
    def set_db_version(self):
 
        ver = DbMigrateVersion()
 
        ver.version = __dbversion__
 
        ver.repository_id = DB_MIGRATIONS
 
        ver.repository_path = 'versions'
 
        self.sa.add(ver)
 
        log.info('db version set to: %s' % __dbversion__)
 

	
 
    def upgrade(self):
 
        """
 
        Upgrades given database schema to given revision following
 
        all needed steps, to perform the upgrade
 

	
 
        """
 

	
 
        from kallithea.lib.dbmigrate.migrate.versioning import api
 
        from kallithea.lib.dbmigrate.migrate.exceptions import \
 
            DatabaseNotControlledError
 

	
 
        if 'sqlite' in self.dburi:
 
            print (
 
               '********************** WARNING **********************\n'
 
               'Make sure your version of sqlite is at least 3.7.X.  \n'
 
               'Earlier versions are known to fail on some migrations\n'
 
               '*****************************************************\n')
 

	
 
        upgrade = ask_ok('You are about to perform database upgrade, make '
 
                         'sure You backed up your database before. '
 
                         'Continue ? [y/n]')
 
        if not upgrade:
 
            print 'No upgrade performed'
 
            sys.exit(0)
 

	
 
        repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
 
                             'kallithea/lib/dbmigrate')
 
        db_uri = self.dburi
 

	
 
        try:
 
            curr_version = api.db_version(db_uri, repository_path)
 
            msg = ('Found current database under version '
 
                   'control with version %s' % curr_version)
 

	
 
        except (RuntimeError, DatabaseNotControlledError):
 
            curr_version = 1
 
            msg = ('Current database is not under version control. Setting '
 
                   'as version %s' % curr_version)
 
            api.version_control(db_uri, repository_path, curr_version)
 

	
 
        notify(msg)
 
        if curr_version == __dbversion__:
 
            print 'This database is already at the newest version'
 
            sys.exit(0)
 

	
 
        # clear cache keys
 
        log.info("Clearing cache keys now...")
 
        CacheInvalidation.clear_cache()
 

	
 
        upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
 
        notify('attempting to do database upgrade from '
 
               'version %s to version %s' % (curr_version, __dbversion__))
 

	
 
        # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
 
        _step = None
 
        for step in upgrade_steps:
 
            notify('performing upgrade step %s' % step)
 
            time.sleep(0.5)
 

	
 
            api.upgrade(db_uri, repository_path, step)
 
            notify('schema upgrade for step %s completed' % (step,))
 

	
 
            _step = step
 

	
 
        notify('upgrade to version %s successful' % _step)
 

	
 
    def fix_repo_paths(self):
 
        """
 
        Fixes a old kallithea version path into new one without a '*'
 
        """
 

	
 
        paths = self.sa.query(Ui)\
 
                .filter(Ui.ui_key == '/')\
 
                .scalar()
 

	
 
        paths.ui_value = paths.ui_value.replace('*', '')
 

	
 
        try:
 
            self.sa.add(paths)
 
            self.sa.commit()
 
        except Exception:
 
            self.sa.rollback()
 
            raise
 
        self.sa.add(paths)
 
        self.sa.commit()
 

	
 
    def fix_default_user(self):
 
        """
 
        Fixes a old default user with some 'nicer' default values,
 
        used mostly for anonymous access
 
        """
 
        def_user = self.sa.query(User)\
 
                .filter(User.username == User.DEFAULT_USER)\
 
                .one()
 

	
 
        def_user.name = 'Anonymous'
 
        def_user.lastname = 'User'
 
        def_user.email = 'anonymous@kallithea-scm.org'
 

	
 
        try:
 
            self.sa.add(def_user)
 
            self.sa.commit()
 
        except Exception:
 
            self.sa.rollback()
 
            raise
 
        self.sa.add(def_user)
 
        self.sa.commit()
 

	
 
    def fix_settings(self):
 
        """
 
        Fixes kallithea settings adds ga_code key for google analytics
 
        """
 

	
 
        hgsettings3 = Setting('ga_code', '')
 

	
 
        try:
 
            self.sa.add(hgsettings3)
 
            self.sa.commit()
 
        except Exception:
 
            self.sa.rollback()
 
            raise
 
        self.sa.add(hgsettings3)
 
        self.sa.commit()
 

	
 
    def admin_prompt(self, second=False):
 
        if not self.tests:
 
            import getpass
 

	
 
            # defaults
 
            defaults = self.cli_args
 
            username = defaults.get('username')
 
            password = defaults.get('password')
 
            email = defaults.get('email')
 

	
 
            def get_password():
 
                password = getpass.getpass('Specify admin password '
 
                                           '(min 6 chars):')
 
                confirm = getpass.getpass('Confirm password:')
 

	
 
                if password != confirm:
 
                    log.error('passwords mismatch')
 
                    return False
 
                if len(password) < 6:
 
                    log.error('password is to short use at least 6 characters')
 
                    return False
 

	
 
                return password
 
            if username is None:
 
                username = raw_input('Specify admin username:')
 
            if password is None:
 
                password = get_password()
 
                if not password:
 
                    #second try
 
                    password = get_password()
 
                    if not password:
 
                        sys.exit()
 
            if email is None:
 
                email = raw_input('Specify admin email:')
 
            self.create_user(username, password, email, True)
 
        else:
 
            log.info('creating admin and regular test users')
 
            from kallithea.tests import TEST_USER_ADMIN_LOGIN, \
 
            TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
 
            TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
 
            TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
 
            TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
 

	
 
            self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
 
                             TEST_USER_ADMIN_EMAIL, True)
 

	
 
            self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
 
                             TEST_USER_REGULAR_EMAIL, False)
 

	
 
            self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
 
                             TEST_USER_REGULAR2_EMAIL, False)
 

	
 
    def create_ui_settings(self, repo_store_path):
 
        """
 
        Creates ui settings, fills out hooks
 
        """
 

	
 
        #HOOKS
 
        hooks1_key = Ui.HOOK_UPDATE
 
        hooks1_ = self.sa.query(Ui)\
 
            .filter(Ui.ui_key == hooks1_key).scalar()
 

	
 
        hooks1 = Ui() if hooks1_ is None else hooks1_
 
        hooks1.ui_section = 'hooks'
 
        hooks1.ui_key = hooks1_key
 
        hooks1.ui_value = 'hg update >&2'
 
        hooks1.ui_active = False
 
        self.sa.add(hooks1)
 

	
 
        hooks2_key = Ui.HOOK_REPO_SIZE
 
        hooks2_ = self.sa.query(Ui)\
 
            .filter(Ui.ui_key == hooks2_key).scalar()
 
        hooks2 = Ui() if hooks2_ is None else hooks2_
 
        hooks2.ui_section = 'hooks'
 
        hooks2.ui_key = hooks2_key
 
        hooks2.ui_value = 'python:kallithea.lib.hooks.repo_size'
 
        self.sa.add(hooks2)
 

	
 
        hooks3 = Ui()
 
        hooks3.ui_section = 'hooks'
 
        hooks3.ui_key = Ui.HOOK_PUSH
 
        hooks3.ui_value = 'python:kallithea.lib.hooks.log_push_action'
 
        self.sa.add(hooks3)
 

	
 
        hooks4 = Ui()
 
        hooks4.ui_section = 'hooks'
 
        hooks4.ui_key = Ui.HOOK_PRE_PUSH
 
        hooks4.ui_value = 'python:kallithea.lib.hooks.pre_push'
 
        self.sa.add(hooks4)
 

	
 
        hooks5 = Ui()
 
        hooks5.ui_section = 'hooks'
 
        hooks5.ui_key = Ui.HOOK_PULL
 
        hooks5.ui_value = 'python:kallithea.lib.hooks.log_pull_action'
 
        self.sa.add(hooks5)
 

	
 
        hooks6 = Ui()
 
        hooks6.ui_section = 'hooks'
 
        hooks6.ui_key = Ui.HOOK_PRE_PULL
 
        hooks6.ui_value = 'python:kallithea.lib.hooks.pre_pull'
 
        self.sa.add(hooks6)
 

	
 
        # enable largefiles
 
        largefiles = Ui()
 
        largefiles.ui_section = 'extensions'
 
        largefiles.ui_key = 'largefiles'
 
        largefiles.ui_value = ''
 
        self.sa.add(largefiles)
 

	
 
        # set default largefiles cache dir, defaults to
 
        # /repo location/.cache/largefiles
 
        largefiles = Ui()
 
        largefiles.ui_section = 'largefiles'
 
        largefiles.ui_key = 'usercache'
 
        largefiles.ui_value = os.path.join(repo_store_path, '.cache',
 
                                           'largefiles')
 
        self.sa.add(largefiles)
 

	
 
        # enable hgsubversion disabled by default
 
        hgsubversion = Ui()
 
        hgsubversion.ui_section = 'extensions'
 
        hgsubversion.ui_key = 'hgsubversion'
 
        hgsubversion.ui_value = ''
 
        hgsubversion.ui_active = False
 
        self.sa.add(hgsubversion)
 

	
 
        # enable hggit disabled by default
 
        hggit = Ui()
 
        hggit.ui_section = 'extensions'
 
        hggit.ui_key = 'hggit'
 
        hggit.ui_value = ''
 
        hggit.ui_active = False
 
        self.sa.add(hggit)
 

	
 
    def create_auth_plugin_options(self, skip_existing=False):
 
        """
 
        Create default auth plugin settings, and make it active
 

	
 
        :param skip_existing:
 
        """
 

	
 
        for k, v, t in [('auth_plugins', 'kallithea.lib.auth_modules.auth_internal', 'list'),
 
                     ('auth_internal_enabled', 'True', 'bool')]:
 
            if skip_existing and Setting.get_by_name(k) != None:
 
                log.debug('Skipping option %s' % k)
 
                continue
 
            setting = Setting(k, v, t)
 
            self.sa.add(setting)
 

	
 
    def create_default_options(self, skip_existing=False):
 
        """Creates default settings"""
 

	
 
        for k, v, t in [
 
            ('default_repo_enable_locking',  False, 'bool'),
 
            ('default_repo_enable_downloads', False, 'bool'),
 
            ('default_repo_enable_statistics', False, 'bool'),
 
            ('default_repo_private', False, 'bool'),
 
            ('default_repo_type', 'hg', 'unicode')]:
 

	
 
            if skip_existing and Setting.get_by_name(k) is not None:
 
                log.debug('Skipping option %s' % k)
 
                continue
 
            setting = Setting(k, v, t)
 
            self.sa.add(setting)
 

	
 
    def fixup_groups(self):
 
        def_usr = User.get_default_user()
 
        for g in RepoGroup.query().all():
 
            g.group_name = g.get_new_name(g.name)
 
            self.sa.add(g)
 
            # get default perm
 
            default = UserRepoGroupToPerm.query()\
 
                .filter(UserRepoGroupToPerm.group == g)\
 
                .filter(UserRepoGroupToPerm.user == def_usr)\
 
                .scalar()
 

	
 
            if default is None:
 
                log.debug('missing default permission for group %s adding' % g)
 
                perm_obj = RepoGroupModel()._create_default_perms(g)
 
                self.sa.add(perm_obj)
 

	
 
    def reset_permissions(self, username):
 
        """
 
        Resets permissions to default state, usefull when old systems had
 
        bad permissions, we must clean them up
 

	
 
        :param username:
 
        """
 
        default_user = User.get_by_username(username)
 
        if not default_user:
 
            return
 

	
 
        u2p = UserToPerm.query()\
 
            .filter(UserToPerm.user == default_user).all()
 
        fixed = False
 
        if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
 
            for p in u2p:
 
                Session().delete(p)
 
            fixed = True
 
            self.populate_default_permissions()
 
        return fixed
 

	
 
    def update_repo_info(self):
 
        RepoModel.update_repoinfo()
 

	
 
    def config_prompt(self, test_repo_path='', retries=3):
 
        defaults = self.cli_args
 
        _path = defaults.get('repos_location')
 
        if retries == 3:
 
            log.info('Setting up repositories config')
 

	
 
        if _path is not None:
 
            path = _path
 
        elif not self.tests and not test_repo_path:
 
            path = raw_input(
 
                 'Enter a valid absolute path to store repositories. '
 
                 'All repositories in that path will be added automatically:'
 
            )
 
        else:
 
            path = test_repo_path
 
        path_ok = True
 

	
 
        # check proper dir
 
        if not os.path.isdir(path):
 
            path_ok = False
 
            log.error('Given path %s is not a valid directory' % (path,))
 

	
 
        elif not os.path.isabs(path):
 
            path_ok = False
 
            log.error('Given path %s is not an absolute path' % (path,))
 

	
 
        # check if path is at least readable.
 
        if not os.access(path, os.R_OK):
 
            path_ok = False
 
            log.error('Given path %s is not readable' % (path,))
 

	
 
        # check write access, warn user about non writeable paths
 
        elif not os.access(path, os.W_OK) and path_ok:
 
            log.warning('No write permission to given path %s' % (path,))
 
            if not ask_ok('Given path %s is not writeable, do you want to '
 
                          'continue with read only mode ? [y/n]' % (path,)):
 
                log.error('Canceled by user')
 
                sys.exit(-1)
 

	
 
        if retries == 0:
 
            sys.exit('max retries reached')
 
        if not path_ok:
 
            retries -= 1
 
            return self.config_prompt(test_repo_path, retries)
 

	
 
        real_path = os.path.normpath(os.path.realpath(path))
 

	
 
        if real_path != os.path.normpath(path):
 
            log.warning('Using normalized path %s instead of %s' % (real_path, path))
 

	
 
        return real_path
 

	
 
    def create_settings(self, path):
 

	
 
        self.create_ui_settings(path)
 

	
 
        ui_config = [
 
            ('web', 'push_ssl', 'false'),
 
            ('web', 'allow_archive', 'gz zip bz2'),
 
            ('web', 'allow_push', '*'),
 
            ('web', 'baseurl', '/'),
 
            ('paths', '/', path),
 
            #('phases', 'publish', 'false')
 
        ]
 
        for section, key, value in ui_config:
 
            ui_conf = Ui()
 
            setattr(ui_conf, 'ui_section', section)
 
            setattr(ui_conf, 'ui_key', key)
 
            setattr(ui_conf, 'ui_value', value)
 
            self.sa.add(ui_conf)
 

	
 
        settings = [
 
            ('realm', 'Kallithea', 'unicode'),
 
            ('title', '', 'unicode'),
 
            ('ga_code', '', 'unicode'),
 
            ('show_public_icon', True, 'bool'),
 
            ('show_private_icon', True, 'bool'),
 
            ('stylify_metatags', False, 'bool'),
 
            ('dashboard_items', 100, 'int'),
 
            ('admin_grid_items', 25, 'int'),
 
            ('show_version', True, 'bool'),
 
            ('use_gravatar', True, 'bool'),
 
            ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
 
            ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
 
            ('update_url', Setting.DEFAULT_UPDATE_URL, 'unicode'),
 
        ]
 
        for key, val, type_ in settings:
 
            sett = Setting(key, val, type_)
 
            self.sa.add(sett)
 

	
 
        self.create_auth_plugin_options()
 
        self.create_default_options()
 

	
 
        log.info('created ui config')
 

	
 
    def create_user(self, username, password, email='', admin=False):
 
        log.info('creating user %s' % username)
 
        UserModel().create_or_update(username, password, email,
 
                                     firstname='Kallithea', lastname='Admin',
 
                                     active=True, admin=admin,
 
                                     extern_type=EXTERN_TYPE_INTERNAL)
 

	
 
    def create_default_user(self):
 
        log.info('creating default user')
 
        # create default user for handling default permissions.
 
        user = UserModel().create_or_update(username=User.DEFAULT_USER,
 
                                            password=str(uuid.uuid1())[:20],
 
                                            email='anonymous@kallithea-scm.org',
 
                                            firstname='Anonymous',
 
                                            lastname='User')
 
        # based on configuration options activate/deactive this user which
 
        # controlls anonymous access
 
        if self.cli_args.get('public_access') is False:
 
            log.info('Public access disabled')
 
            user.active = False
 
            Session().add(user)
 
            Session().commit()
 

	
 
    def create_permissions(self):
 
        """
 
        Creates all permissions defined in the system
 
        """
 
        # module.(access|create|change|delete)_[name]
 
        # module.(none|read|write|admin)
 
        log.info('creating permissions')
 
        PermissionModel(self.sa).create_permissions()
 

	
 
    def populate_default_permissions(self):
 
        """
 
        Populate default permissions. It will create only the default
 
        permissions that are missing, and not alter already defined ones
 
        """
 
        log.info('creating default user permissions')
 
        PermissionModel(self.sa).create_default_permissions(user=User.DEFAULT_USER)
 

	
 
    @staticmethod
 
    def check_waitress():
 
        """
 
        Function executed at the end of setup
 
        """
 
        if not __py_version__ >= (2, 6):
 
            notify('Python2.5 detected, please switch '
 
                   'egg:waitress#main -> egg:Paste#http '
 
                   'in your .ini file')
kallithea/lib/helpers.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
Helper functions
 

	
 
Consists of functions to typically be used within templates, but also
 
available to Controllers. This module is available to both as 'h'.
 
"""
 
import random
 
import hashlib
 
import StringIO
 
import math
 
import logging
 
import re
 
import urlparse
 
import textwrap
 

	
 
from pygments.formatters.html import HtmlFormatter
 
from pygments import highlight as code_highlight
 
from pylons import url
 
from pylons.i18n.translation import _, ungettext
 
from hashlib import md5
 

	
 
from webhelpers.html import literal, HTML, escape
 
from webhelpers.html.tools import *
 
from webhelpers.html.builder import make_tag
 
from webhelpers.html.tags import auto_discovery_link, checkbox, css_classes, \
 
    end_form, file, form, hidden, image, javascript_link, link_to, \
 
    link_to_if, link_to_unless, ol, required_legend, select, stylesheet_link, \
 
    submit, text, password, textarea, title, ul, xml_declaration, radio
 
from webhelpers.html.tools import auto_link, button_to, highlight, \
 
    js_obfuscate, mail_to, strip_links, strip_tags, tag_re
 
from webhelpers.number import format_byte_size, format_bit_size
 
from webhelpers.pylonslib import Flash as _Flash
 
from webhelpers.pylonslib.secure_form import secure_form
 
from webhelpers.text import chop_at, collapse, convert_accented_entities, \
 
    convert_misc_entities, lchop, plural, rchop, remove_formatting, \
 
    replace_whitespace, urlify, truncate, wrap_paragraphs
 
from webhelpers.date import time_ago_in_words
 
from webhelpers.paginate import Page as _Page
 
from webhelpers.html.tags import _set_input_attrs, _set_id_attr, \
 
    convert_boolean_attrs, NotGiven, _make_safe_id_component
 

	
 
from kallithea.lib.annotate import annotate_highlight
 
from kallithea.lib.utils import repo_name_slug, get_custom_lexer
 
from kallithea.lib.utils2 import str2bool, safe_unicode, safe_str, \
 
    get_changeset_safe, datetime_to_time, time_to_datetime, AttributeDict,\
 
    safe_int
 
from kallithea.lib.markup_renderer import MarkupRenderer, url_re
 
from kallithea.lib.vcs.exceptions import ChangesetDoesNotExistError
 
from kallithea.lib.vcs.backends.base import BaseChangeset, EmptyChangeset
 
from kallithea.config.conf import DATE_FORMAT, DATETIME_FORMAT
 
from kallithea.model.changeset_status import ChangesetStatusModel
 
from kallithea.model.db import URL_SEP, Permission
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
def canonical_url(*args, **kargs):
 
    '''Like url(x, qualified=True), but returns url that not only is qualified
 
    but also canonical, as configured in canonical_url'''
 
    from kallithea import CONFIG
 
    try:
 
        parts = CONFIG.get('canonical_url', '').split('://', 1)
 
        kargs['host'] = parts[1].split('/', 1)[0]
 
        kargs['protocol'] = parts[0]
 
    except IndexError:
 
        kargs['qualified'] = True
 
    return url(*args, **kargs)
 

	
 
def canonical_hostname():
 
    '''Return canonical hostname of system'''
 
    from kallithea import CONFIG
 
    try:
 
        parts = CONFIG.get('canonical_url', '').split('://', 1)
 
        return parts[1].split('/', 1)[0]
 
    except IndexError:
 
        parts = url('home', qualified=True).split('://', 1)
 
        return parts[1].split('/', 1)[0]
 

	
 
def html_escape(text, html_escape_table=None):
 
    """Produce entities within text."""
 
    if not html_escape_table:
 
        html_escape_table = {
 
            "&": "&amp;",
 
            '"': "&quot;",
 
            "'": "&apos;",
 
            ">": "&gt;",
 
            "<": "&lt;",
 
        }
 
    return "".join(html_escape_table.get(c, c) for c in text)
 

	
 

	
 
def shorter(text, size=20):
 
    postfix = '...'
 
    if len(text) > size:
 
        return text[:size - len(postfix)] + postfix
 
    return text
 

	
 

	
 
def _reset(name, value=None, id=NotGiven, type="reset", **attrs):
 
    """
 
    Reset button
 
    """
 
    _set_input_attrs(attrs, type, name, value)
 
    _set_id_attr(attrs, id, name)
 
    convert_boolean_attrs(attrs, ["disabled"])
 
    return HTML.input(**attrs)
 

	
 
reset = _reset
 
safeid = _make_safe_id_component
 

	
 

	
 
def FID(raw_id, path):
 
    """
 
    Creates a uniqe ID for filenode based on it's hash of path and revision
 
    it's safe to use in urls
 

	
 
    :param raw_id:
 
    :param path:
 
    """
 

	
 
    return 'C-%s-%s' % (short_id(raw_id), md5(safe_str(path)).hexdigest()[:12])
 

	
 

	
 
def get_token():
 
    """Return the current authentication token, creating one if one doesn't
 
    already exist.
 
    """
 
    token_key = "_authentication_token"
 
    from pylons import session
 
    if not token_key in session:
 
        try:
 
            token = hashlib.sha1(str(random.getrandbits(128))).hexdigest()
 
        except AttributeError:  # Python < 2.4
 
            token = hashlib.sha1(str(random.randrange(2 ** 128))).hexdigest()
 
        session[token_key] = token
 
        if hasattr(session, 'save'):
 
            session.save()
 
    return session[token_key]
 

	
 

	
 
class _GetError(object):
 
    """Get error from form_errors, and represent it as span wrapped error
 
    message
 

	
 
    :param field_name: field to fetch errors for
 
    :param form_errors: form errors dict
 
    """
 

	
 
    def __call__(self, field_name, form_errors):
 
        tmpl = """<span class="error_msg">%s</span>"""
 
        if form_errors and field_name in form_errors:
 
            return literal(tmpl % form_errors.get(field_name))
 

	
 
get_error = _GetError()
 

	
 

	
 
class _ToolTip(object):
 

	
 
    def __call__(self, tooltip_title, trim_at=50):
 
        """
 
        Special function just to wrap our text into nice formatted
 
        autowrapped text
 

	
 
        :param tooltip_title:
 
        """
 
        tooltip_title = escape(tooltip_title)
 
        tooltip_title = tooltip_title.replace('<', '&lt;').replace('>', '&gt;')
 
        return tooltip_title
 
tooltip = _ToolTip()
 

	
 

	
 
class _FilesBreadCrumbs(object):
 

	
 
    def __call__(self, repo_name, rev, paths):
 
        if isinstance(paths, str):
 
            paths = safe_unicode(paths)
 
        url_l = [link_to(repo_name, url('files_home',
 
                                        repo_name=repo_name,
 
                                        revision=rev, f_path=''),
 
                         class_='ypjax-link')]
 
        paths_l = paths.split('/')
 
        for cnt, p in enumerate(paths_l):
 
            if p != '':
 
                url_l.append(link_to(p,
 
                                     url('files_home',
 
                                         repo_name=repo_name,
 
                                         revision=rev,
 
                                         f_path='/'.join(paths_l[:cnt + 1])
 
                                         ),
 
                                     class_='ypjax-link'
 
                                     )
 
                             )
 

	
 
        return literal('/'.join(url_l))
 

	
 
files_breadcrumbs = _FilesBreadCrumbs()
 

	
 

	
 
class CodeHtmlFormatter(HtmlFormatter):
 
    """
 
    My code Html Formatter for source codes
 
    """
 

	
 
    def wrap(self, source, outfile):
 
        return self._wrap_div(self._wrap_pre(self._wrap_code(source)))
 

	
 
    def _wrap_code(self, source):
 
        for cnt, it in enumerate(source):
 
            i, t = it
 
            t = '<div id="L%s">%s</div>' % (cnt + 1, t)
 
            yield i, t
 

	
 
    def _wrap_tablelinenos(self, inner):
 
        dummyoutfile = StringIO.StringIO()
 
        lncount = 0
 
        for t, line in inner:
 
            if t:
 
                lncount += 1
 
            dummyoutfile.write(line)
 

	
 
        fl = self.linenostart
 
        mw = len(str(lncount + fl - 1))
 
        sp = self.linenospecial
 
        st = self.linenostep
 
        la = self.lineanchors
 
        aln = self.anchorlinenos
 
        nocls = self.noclasses
 
        if sp:
 
            lines = []
 

	
 
            for i in range(fl, fl + lncount):
 
                if i % st == 0:
 
                    if i % sp == 0:
 
                        if aln:
 
                            lines.append('<a href="#%s%d" class="special">%*d</a>' %
 
                                         (la, i, mw, i))
 
                        else:
 
                            lines.append('<span class="special">%*d</span>' % (mw, i))
 
                    else:
 
                        if aln:
 
                            lines.append('<a href="#%s%d">%*d</a>' % (la, i, mw, i))
 
                        else:
 
                            lines.append('%*d' % (mw, i))
 
                else:
 
                    lines.append('')
 
            ls = '\n'.join(lines)
 
        else:
 
            lines = []
 
            for i in range(fl, fl + lncount):
 
                if i % st == 0:
 
                    if aln:
 
                        lines.append('<a href="#%s%d">%*d</a>' % (la, i, mw, i))
 
                    else:
 
                        lines.append('%*d' % (mw, i))
 
                else:
 
                    lines.append('')
 
            ls = '\n'.join(lines)
 

	
 
        # in case you wonder about the seemingly redundant <div> here: since the
 
        # content in the other cell also is wrapped in a div, some browsers in
 
        # some configurations seem to mess up the formatting...
 
        if nocls:
 
            yield 0, ('<table class="%stable">' % self.cssclass +
 
                      '<tr><td><div class="linenodiv" '
 
                      'style="background-color: #f0f0f0; padding-right: 10px">'
 
                      '<pre style="line-height: 125%">' +
 
                      ls + '</pre></div></td><td id="hlcode" class="code">')
 
        else:
 
            yield 0, ('<table class="%stable">' % self.cssclass +
 
                      '<tr><td class="linenos"><div class="linenodiv"><pre>' +
 
                      ls + '</pre></div></td><td id="hlcode" class="code">')
 
        yield 0, dummyoutfile.getvalue()
 
        yield 0, '</td></tr></table>'
 

	
 

	
 
_whitespace_re = re.compile(r'(\t)|( )(?=\n|</div>)')
 

	
 
def _markup_whitespace(m):
 
    groups = m.groups()
 
    if groups[0]:
 
        return '<u>\t</u>'
 
    if groups[1]:
 
        return ' <i></i>'
 

	
 
def markup_whitespace(s):
 
    return _whitespace_re.sub(_markup_whitespace, s)
 

	
 
def pygmentize(filenode, **kwargs):
 
    """
 
    pygmentize function using pygments
 

	
 
    :param filenode:
 
    """
 
    lexer = get_custom_lexer(filenode.extension) or filenode.lexer
 
    return literal(markup_whitespace(
 
        code_highlight(filenode.content, lexer, CodeHtmlFormatter(**kwargs))))
 

	
 

	
 
def pygmentize_annotation(repo_name, filenode, **kwargs):
 
    """
 
    pygmentize function for annotation
 

	
 
    :param filenode:
 
    """
 

	
 
    color_dict = {}
 

	
 
    def gen_color(n=10000):
 
        """generator for getting n of evenly distributed colors using
 
        hsv color and golden ratio. It always return same order of colors
 

	
 
        :returns: RGB tuple
 
        """
 

	
 
        def hsv_to_rgb(h, s, v):
 
            if s == 0.0:
 
                return v, v, v
 
            i = int(h * 6.0)  # XXX assume int() truncates!
 
            f = (h * 6.0) - i
 
            p = v * (1.0 - s)
 
            q = v * (1.0 - s * f)
 
            t = v * (1.0 - s * (1.0 - f))
 
            i = i % 6
 
            if i == 0:
 
                return v, t, p
 
            if i == 1:
 
                return q, v, p
 
            if i == 2:
 
                return p, v, t
 
            if i == 3:
 
                return p, q, v
 
            if i == 4:
 
                return t, p, v
 
            if i == 5:
 
                return v, p, q
 

	
 
        golden_ratio = 0.618033988749895
 
        h = 0.22717784590367374
 

	
 
        for _unused in xrange(n):
 
            h += golden_ratio
 
            h %= 1
 
            HSV_tuple = [h, 0.95, 0.95]
 
            RGB_tuple = hsv_to_rgb(*HSV_tuple)
 
            yield map(lambda x: str(int(x * 256)), RGB_tuple)
 

	
 
    cgenerator = gen_color()
 

	
 
    def get_color_string(cs):
 
        if cs in color_dict:
 
            col = color_dict[cs]
 
        else:
 
            col = color_dict[cs] = cgenerator.next()
 
        return "color: rgb(%s)! important;" % (', '.join(col))
 

	
 
    def url_func(repo_name):
 

	
 
        def _url_func(changeset):
 
            author = changeset.author
 
            date = changeset.date
 
            message = tooltip(changeset.message)
 

	
 
            tooltip_html = ("<div style='font-size:0.8em'><b>Author:</b>"
 
                            " %s<br/><b>Date:</b> %s</b><br/><b>Message:"
 
                            "</b> %s<br/></div>")
 

	
 
            tooltip_html = tooltip_html % (author, date, message)
 
            lnk_format = show_id(changeset)
 
            uri = link_to(
 
                    lnk_format,
 
                    url('changeset_home', repo_name=repo_name,
 
                        revision=changeset.raw_id),
 
                    style=get_color_string(changeset.raw_id),
 
                    class_='tooltip',
 
                    title=tooltip_html
 
                  )
 

	
 
            uri += '\n'
 
            return uri
 
        return _url_func
 

	
 
    return literal(markup_whitespace(annotate_highlight(filenode, url_func(repo_name), **kwargs)))
 

	
 

	
 
def is_following_repo(repo_name, user_id):
 
    from kallithea.model.scm import ScmModel
 
    return ScmModel().is_following_repo(repo_name, user_id)
 

	
 
class _Message(object):
 
    """A message returned by ``Flash.pop_messages()``.
 

	
 
    Converting the message to a string returns the message text. Instances
 
    also have the following attributes:
 

	
 
    * ``message``: the message text.
 
    * ``category``: the category specified when the message was created.
 
    """
 

	
 
    def __init__(self, category, message):
 
        self.category = category
 
        self.message = message
 

	
 
    def __str__(self):
 
        return self.message
 

	
 
    __unicode__ = __str__
 

	
 
    def __html__(self):
 
        return escape(safe_unicode(self.message))
 

	
 
class Flash(_Flash):
 

	
 
    def pop_messages(self):
 
        """Return all accumulated messages and delete them from the session.
 

	
 
        The return value is a list of ``Message`` objects.
 
        """
 
        from pylons import session
 
        messages = session.pop(self.session_key, [])
 
        session.save()
 
        return [_Message(*m) for m in messages]
 

	
 
flash = Flash()
 

	
 
#==============================================================================
 
# SCM FILTERS available via h.
 
#==============================================================================
 
from kallithea.lib.vcs.utils import author_name, author_email
 
from kallithea.lib.utils2 import credentials_filter, age as _age
 
from kallithea.model.db import User, ChangesetStatus
 

	
 
age = lambda  x, y=False: _age(x, y)
 
capitalize = lambda x: x.capitalize()
 
email = author_email
 
short_id = lambda x: x[:12]
 
hide_credentials = lambda x: ''.join(credentials_filter(x))
 

	
 

	
 
def show_id(cs):
 
    """
 
    Configurable function that shows ID
 
    by default it's r123:fffeeefffeee
 

	
 
    :param cs: changeset instance
 
    """
 
    from kallithea import CONFIG
 
    def_len = safe_int(CONFIG.get('show_sha_length', 12))
 
    show_rev = str2bool(CONFIG.get('show_revision_number', False))
 

	
 
    raw_id = cs.raw_id[:def_len]
 
    if show_rev:
 
        return 'r%s:%s' % (cs.revision, raw_id)
 
    else:
 
        return '%s' % (raw_id)
 

	
 

	
 
def fmt_date(date):
 
    if date:
 
        _fmt = u"%a, %d %b %Y %H:%M:%S".encode('utf8')
 
        return date.strftime(_fmt).decode('utf8')
 

	
 
    return ""
 

	
 

	
 
def is_git(repository):
 
    if hasattr(repository, 'alias'):
 
        _type = repository.alias
 
    elif hasattr(repository, 'repo_type'):
 
        _type = repository.repo_type
 
    else:
 
        _type = repository
 
    return _type == 'git'
 

	
 

	
 
def is_hg(repository):
 
    if hasattr(repository, 'alias'):
 
        _type = repository.alias
 
    elif hasattr(repository, 'repo_type'):
 
        _type = repository.repo_type
 
    else:
 
        _type = repository
 
    return _type == 'hg'
 

	
 

	
 
def email_or_none(author):
 
    if not author:
 
        return None
 
    # extract email from the commit string
 
    _email = email(author)
 
    if _email:
 
        # check it against Kallithea database, and use the MAIN email for this
 
        # user
 
        user = User.get_by_email(_email, case_insensitive=True, cache=True)
 
        if user is not None:
 
            return user.email
 
        return _email
 

	
 
    # See if it contains a username we can get an email from
 
    user = User.get_by_username(author_name(author), case_insensitive=True,
 
                                cache=True)
 
    if user is not None:
 
        return user.email
 

	
 
    # No valid email, not a valid user in the system, none!
 
    return None
 

	
 

	
 
def person(author, show_attr="username"):
 
    # attr to return from fetched user
 
    person_getter = lambda usr: getattr(usr, show_attr)
 

	
 
    # if author is already an instance use it for extraction
 
    if isinstance(author, User):
 
        return person_getter(author)
 

	
 
    # Valid email in the attribute passed, see if they're in the system
 
    _email = email(author)
 
    if _email:
 
        user = User.get_by_email(_email, case_insensitive=True, cache=True)
 
        if user is not None:
 
            return person_getter(user)
 

	
 
    # Maybe it's a username?
 
    _author = author_name(author)
 
    user = User.get_by_username(_author, case_insensitive=True,
 
                                cache=True)
 
    if user is not None:
 
        return person_getter(user)
 

	
 
    # Still nothing?  Just pass back the author name if any, else the email
 
    return _author or _email
 

	
 

	
 
def person_by_id(id_, show_attr="username"):
 
    # attr to return from fetched user
 
    person_getter = lambda usr: getattr(usr, show_attr)
 

	
 
    #maybe it's an ID ?
 
    if str(id_).isdigit() or isinstance(id_, int):
 
        id_ = int(id_)
 
        user = User.get(id_)
 
        if user is not None:
 
            return person_getter(user)
 
    return id_
 

	
 

	
 
def desc_stylize(value):
 
    """
 
    converts tags from value into html equivalent
 

	
 
    :param value:
 
    """
 
    if not value:
 
        return ''
 

	
 
    value = re.sub(r'\[see\ \=\>\ *([a-zA-Z0-9\/\=\?\&\ \:\/\.\-]*)\]',
 
                   '<div class="metatag" tag="see">see =&gt; \\1 </div>', value)
 
    value = re.sub(r'\[license\ \=\>\ *([a-zA-Z0-9\/\=\?\&\ \:\/\.\-]*)\]',
 
                   '<div class="metatag" tag="license"><a href="http:\/\/www.opensource.org/licenses/\\1">\\1</a></div>', value)
 
    value = re.sub(r'\[(requires|recommends|conflicts|base)\ \=\>\ *([a-zA-Z0-9\-\/]*)\]',
 
                   '<div class="metatag" tag="\\1">\\1 =&gt; <a href="/\\2">\\2</a></div>', value)
 
    value = re.sub(r'\[(lang|language)\ \=\>\ *([a-zA-Z\-\/\#\+]*)\]',
 
                   '<div class="metatag" tag="lang">\\2</div>', value)
 
    value = re.sub(r'\[([a-z]+)\]',
 
                  '<div class="metatag" tag="\\1">\\1</div>', value)
 

	
 
    return value
 

	
 

	
 
def boolicon(value):
 
    """Returns boolean value of a value, represented as small html image of true/false
 
    icons
 

	
 
    :param value: value
 
    """
 

	
 
    if value:
 
        return HTML.tag('i', class_="icon-ok")
 
    else:
 
        return HTML.tag('i', class_="icon-minus-circled")
 

	
 

	
 
def action_parser(user_log, feed=False, parse_cs=False):
 
    """
 
    This helper will action_map the specified string action into translated
 
    fancy names with icons and links
 

	
 
    :param user_log: user log instance
 
    :param feed: use output for feeds (no html and fancy icons)
 
    :param parse_cs: parse Changesets into VCS instances
 
    """
 

	
 
    action = user_log.action
 
    action_params = ' '
 

	
 
    x = action.split(':')
 

	
 
    if len(x) > 1:
 
        action, action_params = x
 

	
 
    def get_cs_links():
 
        revs_limit = 3  # display this amount always
 
        revs_top_limit = 50  # show upto this amount of changesets hidden
 
        revs_ids = action_params.split(',')
 
        deleted = user_log.repository is None
 
        if deleted:
 
            return ','.join(revs_ids)
 

	
 
        repo_name = user_log.repository.repo_name
 

	
 
        def lnk(rev, repo_name):
 
            if isinstance(rev, BaseChangeset) or isinstance(rev, AttributeDict):
 
                lazy_cs = True
 
                if getattr(rev, 'op', None) and getattr(rev, 'ref_name', None):
 
                    lazy_cs = False
 
                    lbl = '?'
 
                    if rev.op == 'delete_branch':
 
                        lbl = '%s' % _('Deleted branch: %s') % rev.ref_name
 
                        title = ''
 
                    elif rev.op == 'tag':
 
                        lbl = '%s' % _('Created tag: %s') % rev.ref_name
 
                        title = ''
 
                    _url = '#'
 

	
 
                else:
 
                    lbl = '%s' % (rev.short_id[:8])
 
                    _url = url('changeset_home', repo_name=repo_name,
 
                               revision=rev.raw_id)
 
                    title = tooltip(rev.message)
 
            else:
 
                ## changeset cannot be found/striped/removed etc.
 
                lbl = ('%s' % rev)[:12]
 
                _url = '#'
 
                title = _('Changeset not found')
 
            if parse_cs:
 
                return link_to(lbl, _url, title=title, class_='tooltip')
 
            return link_to(lbl, _url, raw_id=rev.raw_id, repo_name=repo_name,
 
                           class_='lazy-cs' if lazy_cs else '')
 

	
 
        def _get_op(rev_txt):
 
            _op = None
 
            _name = rev_txt
 
            if len(rev_txt.split('=>')) == 2:
 
                _op, _name = rev_txt.split('=>')
 
            return _op, _name
 

	
 
        revs = []
 
        if len(filter(lambda v: v != '', revs_ids)) > 0:
 
            repo = None
 
            for rev in revs_ids[:revs_top_limit]:
 
                _op, _name = _get_op(rev)
 

	
 
                # we want parsed changesets, or new log store format is bad
 
                if parse_cs:
 
                    try:
 
                        if repo is None:
 
                            repo = user_log.repository.scm_instance
 
                        _rev = repo.get_changeset(rev)
 
                        revs.append(_rev)
 
                    except ChangesetDoesNotExistError:
 
                        log.error('cannot find revision %s in this repo' % rev)
 
                        revs.append(rev)
 
                        continue
 
                else:
 
                    _rev = AttributeDict({
 
                        'short_id': rev[:12],
 
                        'raw_id': rev,
 
                        'message': '',
 
                        'op': _op,
 
                        'ref_name': _name
 
                    })
 
                    revs.append(_rev)
 
        cs_links = [" " + ', '.join(
 
            [lnk(rev, repo_name) for rev in revs[:revs_limit]]
 
        )]
 
        _op1, _name1 = _get_op(revs_ids[0])
 
        _op2, _name2 = _get_op(revs_ids[-1])
 

	
 
        _rev = '%s...%s' % (_name1, _name2)
 

	
 
        compare_view = (
 
            ' <div class="compare_view tooltip" title="%s">'
 
            '<a href="%s">%s</a> </div>' % (
 
                _('Show all combined changesets %s->%s') % (
 
                    revs_ids[0][:12], revs_ids[-1][:12]
 
                ),
 
                url('changeset_home', repo_name=repo_name,
 
                    revision=_rev
 
                ),
 
                _('compare view')
 
            )
 
        )
 

	
 
        # if we have exactly one more than normally displayed
 
        # just display it, takes less space than displaying
 
        # "and 1 more revisions"
 
        if len(revs_ids) == revs_limit + 1:
 
            cs_links.append(", " + lnk(revs[revs_limit], repo_name))
 

	
 
        # hidden-by-default ones
 
        if len(revs_ids) > revs_limit + 1:
 
            uniq_id = revs_ids[0]
 
            html_tmpl = (
 
                '<span> %s <a class="show_more" id="_%s" '
 
                'href="#more">%s</a> %s</span>'
 
            )
 
            if not feed:
 
                cs_links.append(html_tmpl % (
 
                      _('and'),
 
                      uniq_id, _('%s more') % (len(revs_ids) - revs_limit),
 
                      _('revisions')
 
                    )
 
                )
 

	
 
            if not feed:
 
                html_tmpl = '<span id="%s" style="display:none">, %s </span>'
 
            else:
 
                html_tmpl = '<span id="%s"> %s </span>'
 

	
 
            morelinks = ', '.join(
 
              [lnk(rev, repo_name) for rev in revs[revs_limit:]]
 
            )
 

	
 
            if len(revs_ids) > revs_top_limit:
 
                morelinks += ', ...'
 

	
 
            cs_links.append(html_tmpl % (uniq_id, morelinks))
 
        if len(revs) > 1:
 
            cs_links.append(compare_view)
 
        return ''.join(cs_links)
 

	
 
    def get_fork_name():
 
        repo_name = action_params
 
        _url = url('summary_home', repo_name=repo_name)
 
        return _('fork name %s') % link_to(action_params, _url)
 

	
 
    def get_user_name():
 
        user_name = action_params
 
        return user_name
 

	
 
    def get_users_group():
 
        group_name = action_params
 
        return group_name
 

	
 
    def get_pull_request():
 
        pull_request_id = action_params
 
        deleted = user_log.repository is None
 
        if deleted:
 
            repo_name = user_log.repository_name
 
        else:
 
            repo_name = user_log.repository.repo_name
 
        return link_to(_('Pull request #%s') % pull_request_id,
 
                    url('pullrequest_show', repo_name=repo_name,
 
                    pull_request_id=pull_request_id))
 

	
 
    def get_archive_name():
 
        archive_name = action_params
 
        return archive_name
 

	
 
    # action : translated str, callback(extractor), icon
 
    action_map = {
 
    'user_deleted_repo':           (_('[deleted] repository'),
 
                                    None, 'icon-trashcan'),
 
    'user_created_repo':           (_('[created] repository'),
 
                                    None, 'icon-plus icon-plus-colored'),
 
    'user_created_fork':           (_('[created] repository as fork'),
 
                                    None, 'icon-fork'),
 
    'user_forked_repo':            (_('[forked] repository'),
 
                                    get_fork_name, 'icon-fork'),
 
    'user_updated_repo':           (_('[updated] repository'),
 
                                    None, 'icon-pencil icon-pencil-colored'),
 
    'user_downloaded_archive':      (_('[downloaded] archive from repository'),
 
                                    get_archive_name, 'icon-download-cloud'),
 
    'admin_deleted_repo':          (_('[delete] repository'),
 
                                    None, 'icon-trashcan'),
 
    'admin_created_repo':          (_('[created] repository'),
 
                                    None, 'icon-plus icon-plus-colored'),
 
    'admin_forked_repo':           (_('[forked] repository'),
 
                                    None, 'icon-fork icon-fork-colored'),
 
    'admin_updated_repo':          (_('[updated] repository'),
 
                                    None, 'icon-pencil icon-pencil-colored'),
 
    'admin_created_user':          (_('[created] user'),
 
                                    get_user_name, 'icon-user icon-user-colored'),
 
    'admin_updated_user':          (_('[updated] user'),
 
                                    get_user_name, 'icon-user icon-user-colored'),
 
    'admin_created_users_group':   (_('[created] user group'),
 
                                    get_users_group, 'icon-pencil icon-pencil-colored'),
 
    'admin_updated_users_group':   (_('[updated] user group'),
 
                                    get_users_group, 'icon-pencil icon-pencil-colored'),
 
    'user_commented_revision':     (_('[commented] on revision in repository'),
 
                                    get_cs_links, 'icon-comment icon-comment-colored'),
 
    'user_commented_pull_request': (_('[commented] on pull request for'),
 
                                    get_pull_request, 'icon-comment icon-comment-colored'),
 
    'user_closed_pull_request':    (_('[closed] pull request for'),
 
                                    get_pull_request, 'icon-ok'),
 
    'push':                        (_('[pushed] into'),
 
                                    get_cs_links, 'icon-move-up'),
 
    'push_local':                  (_('[committed via Kallithea] into repository'),
 
                                    get_cs_links, 'icon-pencil icon-pencil-colored'),
 
    'push_remote':                 (_('[pulled from remote] into repository'),
 
                                    get_cs_links, 'icon-move-up'),
 
    'pull':                        (_('[pulled] from'),
 
                                    None, 'icon-move-down'),
 
    'started_following_repo':      (_('[started following] repository'),
 
                                    None, 'icon-heart icon-heart-colored'),
 
    'stopped_following_repo':      (_('[stopped following] repository'),
 
                                    None, 'icon-heart-empty icon-heart-colored'),
 
    }
 

	
 
    action_str = action_map.get(action, action)
 
    if feed:
 
        action = action_str[0].replace('[', '').replace(']', '')
 
    else:
 
        action = action_str[0]\
 
            .replace('[', '<span class="journal_highlight">')\
 
            .replace(']', '</span>')
 

	
 
    action_params_func = lambda: ""
 

	
 
    if callable(action_str[1]):
 
        action_params_func = action_str[1]
 

	
 
    def action_parser_icon():
 
        action = user_log.action
 
        action_params = None
 
        x = action.split(':')
 

	
 
        if len(x) > 1:
 
            action, action_params = x
 

	
 
        tmpl = """<i class="%s" alt="%s"></i>"""
 
        ico = action_map.get(action, ['', '', ''])[2]
 
        return literal(tmpl % (ico, action))
 

	
 
    # returned callbacks we need to call to get
 
    return [lambda: literal(action), action_params_func, action_parser_icon]
 

	
 

	
 

	
 
#==============================================================================
 
# PERMS
 
#==============================================================================
 
from kallithea.lib.auth import HasPermissionAny, HasPermissionAll, \
 
HasRepoPermissionAny, HasRepoPermissionAll, HasRepoGroupPermissionAll, \
 
HasRepoGroupPermissionAny
 

	
 

	
 
#==============================================================================
 
# GRAVATAR URL
 
#==============================================================================
 

	
 
def gravatar_url(email_address, size=30, ssl_enabled=True):
 
    # doh, we need to re-import those to mock it later
 
    from pylons import url
 
    from pylons import tmpl_context as c
 

	
 
    _def = 'anonymous@kallithea-scm.org'  # default gravatar
 
    _use_gravatar = c.visual.use_gravatar
 
    _gravatar_url = c.visual.gravatar_url or User.DEFAULT_GRAVATAR_URL
 

	
 
    email_address = email_address or _def
 
    if isinstance(email_address, unicode):
 
        #hashlib crashes on unicode items
 
        email_address = safe_str(email_address)
 

	
 
    if not _use_gravatar or not email_address or email_address == _def:
 
        # pick best matching size to one given in size param
 
        f = lambda a, l: min(l, key=lambda x: abs(x - a))
 
        return url("/images/user%s.png" % f(size, [14, 16, 20, 24, 30]))
 

	
 
    if _use_gravatar:
 
        _md5 = lambda s: hashlib.md5(s).hexdigest()
 

	
 
        tmpl = _gravatar_url
 
        parsed_url = urlparse.urlparse(url.current(qualified=True))
 
        tmpl = tmpl.replace('{email}', email_address)\
 
                   .replace('{md5email}', _md5(email_address.lower())) \
 
                   .replace('{netloc}', parsed_url.netloc)\
 
                   .replace('{scheme}', parsed_url.scheme)\
 
                   .replace('{size}', safe_str(size))
 
        return tmpl
 

	
 
class Page(_Page):
 
    """
 
    Custom pager to match rendering style with YUI paginator
 
    """
 

	
 
    def _get_pos(self, cur_page, max_page, items):
 
        edge = (items / 2) + 1
 
        if (cur_page <= edge):
 
            radius = max(items / 2, items - cur_page)
 
        elif (max_page - cur_page) < edge:
 
            radius = (items - 1) - (max_page - cur_page)
 
        else:
 
            radius = items / 2
 

	
 
        left = max(1, (cur_page - (radius)))
 
        right = min(max_page, cur_page + (radius))
 
        return left, cur_page, right
 

	
 
    def _range(self, regexp_match):
 
        """
 
        Return range of linked pages (e.g. '1 2 [3] 4 5 6 7 8').
 

	
 
        Arguments:
 

	
 
        regexp_match
 
            A "re" (regular expressions) match object containing the
 
            radius of linked pages around the current page in
 
            regexp_match.group(1) as a string
 

	
 
        This function is supposed to be called as a callable in
 
        re.sub.
 

	
 
        """
 
        radius = int(regexp_match.group(1))
 

	
 
        # Compute the first and last page number within the radius
 
        # e.g. '1 .. 5 6 [7] 8 9 .. 12'
 
        # -> leftmost_page  = 5
 
        # -> rightmost_page = 9
 
        leftmost_page, _cur, rightmost_page = self._get_pos(self.page,
 
                                                            self.last_page,
 
                                                            (radius * 2) + 1)
 
        nav_items = []
 

	
 
        # Create a link to the first page (unless we are on the first page
 
        # or there would be no need to insert '..' spacers)
 
        if self.page != self.first_page and self.first_page < leftmost_page:
 
            nav_items.append(self._pagerlink(self.first_page, self.first_page))
 

	
 
        # Insert dots if there are pages between the first page
 
        # and the currently displayed page range
 
        if leftmost_page - self.first_page > 1:
 
            # Wrap in a SPAN tag if nolink_attr is set
 
            text = '..'
 
            if self.dotdot_attr:
 
                text = HTML.span(c=text, **self.dotdot_attr)
 
            nav_items.append(text)
 

	
 
        for thispage in xrange(leftmost_page, rightmost_page + 1):
 
            # Hilight the current page number and do not use a link
 
            if thispage == self.page:
 
                text = '%s' % (thispage,)
 
                # Wrap in a SPAN tag if nolink_attr is set
 
                if self.curpage_attr:
 
                    text = HTML.span(c=text, **self.curpage_attr)
 
                nav_items.append(text)
 
            # Otherwise create just a link to that page
 
            else:
 
                text = '%s' % (thispage,)
 
                nav_items.append(self._pagerlink(thispage, text))
 

	
 
        # Insert dots if there are pages between the displayed
 
        # page numbers and the end of the page range
 
        if self.last_page - rightmost_page > 1:
 
            text = '..'
 
            # Wrap in a SPAN tag if nolink_attr is set
 
            if self.dotdot_attr:
 
                text = HTML.span(c=text, **self.dotdot_attr)
 
            nav_items.append(text)
 

	
 
        # Create a link to the very last page (unless we are on the last
 
        # page or there would be no need to insert '..' spacers)
 
        if self.page != self.last_page and rightmost_page < self.last_page:
 
            nav_items.append(self._pagerlink(self.last_page, self.last_page))
 

	
 
        #_page_link = url.current()
 
        #nav_items.append(literal('<link rel="prerender" href="%s?page=%s">' % (_page_link, str(int(self.page)+1))))
 
        #nav_items.append(literal('<link rel="prefetch" href="%s?page=%s">' % (_page_link, str(int(self.page)+1))))
 
        return self.separator.join(nav_items)
 

	
 
    def pager(self, format='~2~', page_param='page', partial_param='partial',
 
        show_if_single_page=False, separator=' ', onclick=None,
 
        symbol_first='<<', symbol_last='>>',
 
        symbol_previous='<', symbol_next='>',
 
        link_attr={'class': 'pager_link', 'rel': 'prerender'},
 
        curpage_attr={'class': 'pager_curpage'},
 
        dotdot_attr={'class': 'pager_dotdot'}, **kwargs):
 

	
 
        self.curpage_attr = curpage_attr
 
        self.separator = separator
 
        self.pager_kwargs = kwargs
 
        self.page_param = page_param
 
        self.partial_param = partial_param
 
        self.onclick = onclick
 
        self.link_attr = link_attr
 
        self.dotdot_attr = dotdot_attr
 

	
 
        # Don't show navigator if there is no more than one page
 
        if self.page_count == 0 or (self.page_count == 1 and not show_if_single_page):
 
            return ''
 

	
 
        from string import Template
 
        # Replace ~...~ in token format by range of pages
 
        result = re.sub(r'~(\d+)~', self._range, format)
 

	
 
        # Interpolate '%' variables
 
        result = Template(result).safe_substitute({
 
            'first_page': self.first_page,
 
            'last_page': self.last_page,
 
            'page': self.page,
 
            'page_count': self.page_count,
 
            'items_per_page': self.items_per_page,
 
            'first_item': self.first_item,
 
            'last_item': self.last_item,
 
            'item_count': self.item_count,
 
            'link_first': self.page > self.first_page and \
 
                    self._pagerlink(self.first_page, symbol_first) or '',
 
            'link_last': self.page < self.last_page and \
 
                    self._pagerlink(self.last_page, symbol_last) or '',
 
            'link_previous': self.previous_page and \
 
                    self._pagerlink(self.previous_page, symbol_previous) \
 
                    or HTML.span(symbol_previous, class_="yui-pg-previous"),
 
            'link_next': self.next_page and \
 
                    self._pagerlink(self.next_page, symbol_next) \
 
                    or HTML.span(symbol_next, class_="yui-pg-next")
 
        })
 

	
 
        return literal(result)
 

	
 

	
 
#==============================================================================
 
# REPO PAGER, PAGER FOR REPOSITORY
 
#==============================================================================
 
class RepoPage(Page):
 

	
 
    def __init__(self, collection, page=1, items_per_page=20,
 
                 item_count=None, url=None, **kwargs):
 

	
 
        """Create a "RepoPage" instance. special pager for paging
 
        repository
 
        """
 
        self._url_generator = url
 

	
 
        # Safe the kwargs class-wide so they can be used in the pager() method
 
        self.kwargs = kwargs
 

	
 
        # Save a reference to the collection
 
        self.original_collection = collection
 

	
 
        self.collection = collection
 

	
 
        # The self.page is the number of the current page.
 
        # The first page has the number 1!
 
        try:
 
            self.page = int(page)  # make it int() if we get it as a string
 
        except (ValueError, TypeError):
 
            self.page = 1
 

	
 
        self.items_per_page = items_per_page
 

	
 
        # Unless the user tells us how many items the collections has
 
        # we calculate that ourselves.
 
        if item_count is not None:
 
            self.item_count = item_count
 
        else:
 
            self.item_count = len(self.collection)
 

	
 
        # Compute the number of the first and last available page
 
        if self.item_count > 0:
 
            self.first_page = 1
 
            self.page_count = int(math.ceil(float(self.item_count) /
 
                                            self.items_per_page))
 
            self.last_page = self.first_page + self.page_count - 1
 

	
 
            # Make sure that the requested page number is the range of
 
            # valid pages
 
            if self.page > self.last_page:
 
                self.page = self.last_page
 
            elif self.page < self.first_page:
 
                self.page = self.first_page
 

	
 
            # Note: the number of items on this page can be less than
 
            #       items_per_page if the last page is not full
 
            self.first_item = max(0, (self.item_count) - (self.page *
 
                                                          items_per_page))
 
            self.last_item = ((self.item_count - 1) - items_per_page *
 
                              (self.page - 1))
 

	
 
            self.items = list(self.collection[self.first_item:self.last_item + 1])
 

	
 
            # Links to previous and next page
 
            if self.page > self.first_page:
 
                self.previous_page = self.page - 1
 
            else:
 
                self.previous_page = None
 

	
 
            if self.page < self.last_page:
 
                self.next_page = self.page + 1
 
            else:
 
                self.next_page = None
 

	
 
        # No items available
 
        else:
 
            self.first_page = None
 
            self.page_count = 0
 
            self.last_page = None
 
            self.first_item = None
 
            self.last_item = None
 
            self.previous_page = None
 
            self.next_page = None
 
            self.items = []
 

	
 
        # This is a subclass of the 'list' type. Initialise the list now.
 
        list.__init__(self, reversed(self.items))
 

	
 

	
 
def changed_tooltip(nodes):
 
    """
 
    Generates a html string for changed nodes in changeset page.
 
    It limits the output to 30 entries
 

	
 
    :param nodes: LazyNodesGenerator
 
    """
 
    if nodes:
 
        pref = ': <br/> '
 
        suf = ''
 
        if len(nodes) > 30:
 
            suf = '<br/>' + _(' and %s more') % (len(nodes) - 30)
 
        return literal(pref + '<br/> '.join([safe_unicode(x.path)
 
                                             for x in nodes[:30]]) + suf)
 
    else:
 
        return ': ' + _('No Files')
 

	
 

	
 
def repo_link(groups_and_repos):
 
    """
 
    Makes a breadcrumbs link to repo within a group
 
    joins &raquo; on each group to create a fancy link
 

	
 
    ex::
 
        group >> subgroup >> repo
 

	
 
    :param groups_and_repos:
 
    :param last_url:
 
    """
 
    groups, just_name, repo_name = groups_and_repos
 
    last_url = url('summary_home', repo_name=repo_name)
 
    last_link = link_to(just_name, last_url)
 

	
 
    def make_link(group):
 
        return link_to(group.name,
 
                       url('repos_group_home', group_name=group.group_name))
 
    return literal(' &raquo; '.join(map(make_link, groups) + ['<span>%s</span>' % last_link]))
 

	
 

	
 
def fancy_file_stats(stats):
 
    """
 
    Displays a fancy two colored bar for number of added/deleted
 
    lines of code on file
 

	
 
    :param stats: two element list of added/deleted lines of code
 
    """
 
    from kallithea.lib.diffs import NEW_FILENODE, DEL_FILENODE, \
 
        MOD_FILENODE, RENAMED_FILENODE, CHMOD_FILENODE, BIN_FILENODE
 

	
 
    def cgen(l_type, a_v, d_v):
 
        mapping = {'tr': 'top-right-rounded-corner-mid',
 
                   'tl': 'top-left-rounded-corner-mid',
 
                   'br': 'bottom-right-rounded-corner-mid',
 
                   'bl': 'bottom-left-rounded-corner-mid'}
 
        map_getter = lambda x: mapping[x]
 

	
 
        if l_type == 'a' and d_v:
 
            #case when added and deleted are present
 
            return ' '.join(map(map_getter, ['tl', 'bl']))
 

	
 
        if l_type == 'a' and not d_v:
 
            return ' '.join(map(map_getter, ['tr', 'br', 'tl', 'bl']))
 

	
 
        if l_type == 'd' and a_v:
 
            return ' '.join(map(map_getter, ['tr', 'br']))
 

	
 
        if l_type == 'd' and not a_v:
 
            return ' '.join(map(map_getter, ['tr', 'br', 'tl', 'bl']))
 

	
 
    a, d = stats['added'], stats['deleted']
 
    width = 100
 

	
 
    if stats['binary']:
 
        #binary mode
 
        lbl = ''
 
        bin_op = 1
 

	
 
        if BIN_FILENODE in stats['ops']:
 
            lbl = 'bin+'
 

	
 
        if NEW_FILENODE in stats['ops']:
 
            lbl += _('new file')
 
            bin_op = NEW_FILENODE
 
        elif MOD_FILENODE in stats['ops']:
 
            lbl += _('mod')
 
            bin_op = MOD_FILENODE
 
        elif DEL_FILENODE in stats['ops']:
 
            lbl += _('del')
 
            bin_op = DEL_FILENODE
 
        elif RENAMED_FILENODE in stats['ops']:
 
            lbl += _('rename')
 
            bin_op = RENAMED_FILENODE
 

	
 
        #chmod can go with other operations
 
        if CHMOD_FILENODE in stats['ops']:
 
            _org_lbl = _('chmod')
 
            lbl += _org_lbl if lbl.endswith('+') else '+%s' % _org_lbl
 

	
 
        #import ipdb;ipdb.set_trace()
 
        b_d = '<div class="bin bin%s %s" style="width:100%%">%s</div>' % (bin_op, cgen('a', a_v='', d_v=0), lbl)
 
        b_a = '<div class="bin bin1" style="width:0%%"></div>'
 
        return literal('<div style="width:%spx">%s%s</div>' % (width, b_a, b_d))
 

	
 
    t = stats['added'] + stats['deleted']
 
    unit = float(width) / (t or 1)
 

	
 
    # needs > 9% of width to be visible or 0 to be hidden
 
    a_p = max(9, unit * a) if a > 0 else 0
 
    d_p = max(9, unit * d) if d > 0 else 0
 
    p_sum = a_p + d_p
 

	
 
    if p_sum > width:
 
        #adjust the percentage to be == 100% since we adjusted to 9
 
        if a_p > d_p:
 
            a_p = a_p - (p_sum - width)
 
        else:
 
            d_p = d_p - (p_sum - width)
 

	
 
    a_v = a if a > 0 else ''
 
    d_v = d if d > 0 else ''
 

	
 
    d_a = '<div class="added %s" style="width:%s%%">%s</div>' % (
 
        cgen('a', a_v, d_v), a_p, a_v
 
    )
 
    d_d = '<div class="deleted %s" style="width:%s%%">%s</div>' % (
 
        cgen('d', a_v, d_v), d_p, d_v
 
    )
 
    return literal('<div style="width:%spx">%s%s</div>' % (width, d_a, d_d))
 

	
 

	
 
def urlify_text(text_, safe=True):
 
    """
 
    Extrac urls from text and make html links out of them
 

	
 
    :param text_:
 
    """
 

	
 
    def url_func(match_obj):
 
        url_full = match_obj.groups()[0]
 
        return '<a href="%(url)s">%(url)s</a>' % ({'url': url_full})
 
    _newtext = url_re.sub(url_func, text_)
 
    if safe:
 
        return literal(_newtext)
 
    return _newtext
 

	
 

	
 
def urlify_changesets(text_, repository):
 
    """
 
    Extract revision ids from changeset and make link from them
 

	
 
    :param text_:
 
    :param repository: repo name to build the URL with
 
    """
 
    from pylons import url  # doh, we need to re-import url to mock it later
 

	
 
    def url_func(match_obj):
 
        rev = match_obj.group(0)
 
        return '<a class="revision-link" href="%(url)s">%(rev)s</a>' % {
 
         'url': url('changeset_home', repo_name=repository, revision=rev),
 
         'rev': rev,
 
        }
 

	
 
    return re.sub(r'(?:^|(?<=[\s(),]))([0-9a-fA-F]{12,40})(?=$|\s|[.,:()])', url_func, text_)
 

	
 
def linkify_others(t, l):
 
    urls = re.compile(r'(\<a.*?\<\/a\>)',)
 
    links = []
 
    for e in urls.split(t):
 
        if not urls.match(e):
 
            links.append('<a class="message-link" href="%s">%s</a>' % (l, e))
 
        else:
 
            links.append(e)
 

	
 
    return ''.join(links)
 

	
 
def urlify_commit(text_, repository, link_=None):
 
    """
 
    Parses given text message and makes proper links.
 
    issues are linked to given issue-server, and rest is a changeset link
 
    if link_ is given, in other case it's a plain text
 

	
 
    :param text_:
 
    :param repository:
 
    :param link_: changeset link
 
    """
 
    def escaper(string):
 
        return string.replace('<', '&lt;').replace('>', '&gt;')
 

	
 
    # urlify changesets - extract revisions and make link out of them
 
    newtext = urlify_changesets(escaper(text_), repository)
 

	
 
    # extract http/https links and make them real urls
 
    newtext = urlify_text(newtext, safe=False)
 

	
 
    newtext = urlify_issues(newtext, repository, link_)
 

	
 
    return literal(newtext)
 

	
 
def urlify_issues(newtext, repository, link_=None):
 
    try:
 
        import traceback
 
        from kallithea import CONFIG
 
        conf = CONFIG
 
    from kallithea import CONFIG as conf
 

	
 
        # allow multiple issue servers to be used
 
        valid_indices = [
 
            x.group(1)
 
            for x in map(lambda x: re.match(r'issue_pat(.*)', x), conf.keys())
 
            if x and 'issue_server_link%s' % x.group(1) in conf
 
            and 'issue_prefix%s' % x.group(1) in conf
 
        ]
 
    # allow multiple issue servers to be used
 
    valid_indices = [
 
        x.group(1)
 
        for x in map(lambda x: re.match(r'issue_pat(.*)', x), conf.keys())
 
        if x and 'issue_server_link%s' % x.group(1) in conf
 
        and 'issue_prefix%s' % x.group(1) in conf
 
    ]
 

	
 
        if valid_indices:
 
            log.debug('found issue server suffixes `%s` during valuation of: %s'
 
                      % (','.join(valid_indices), newtext))
 
    if valid_indices:
 
        log.debug('found issue server suffixes `%s` during valuation of: %s'
 
                  % (','.join(valid_indices), newtext))
 

	
 
        for pattern_index in valid_indices:
 
            ISSUE_PATTERN = conf.get('issue_pat%s' % pattern_index)
 
            ISSUE_SERVER_LNK = conf.get('issue_server_link%s' % pattern_index)
 
            ISSUE_PREFIX = conf.get('issue_prefix%s' % pattern_index)
 
    for pattern_index in valid_indices:
 
        ISSUE_PATTERN = conf.get('issue_pat%s' % pattern_index)
 
        ISSUE_SERVER_LNK = conf.get('issue_server_link%s' % pattern_index)
 
        ISSUE_PREFIX = conf.get('issue_prefix%s' % pattern_index)
 

	
 
            log.debug('pattern suffix `%s` PAT:%s SERVER_LINK:%s PREFIX:%s'
 
                      % (pattern_index, ISSUE_PATTERN, ISSUE_SERVER_LNK,
 
                         ISSUE_PREFIX))
 
        log.debug('pattern suffix `%s` PAT:%s SERVER_LINK:%s PREFIX:%s'
 
                  % (pattern_index, ISSUE_PATTERN, ISSUE_SERVER_LNK,
 
                     ISSUE_PREFIX))
 

	
 
            URL_PAT = re.compile(r'%s' % ISSUE_PATTERN)
 
        URL_PAT = re.compile(r'%s' % ISSUE_PATTERN)
 

	
 
            def url_func(match_obj):
 
                pref = ''
 
                if match_obj.group().startswith(' '):
 
                    pref = ' '
 
        def url_func(match_obj):
 
            pref = ''
 
            if match_obj.group().startswith(' '):
 
                pref = ' '
 

	
 
                issue_id = ''.join(match_obj.groups())
 
                issue_url = ISSUE_SERVER_LNK.replace('{id}', issue_id)
 
                if repository:
 
                    issue_url = issue_url.replace('{repo}', repository)
 
                    repo_name = repository.split(URL_SEP)[-1]
 
                    issue_url = issue_url.replace('{repo_name}', repo_name)
 
            issue_id = ''.join(match_obj.groups())
 
            issue_url = ISSUE_SERVER_LNK.replace('{id}', issue_id)
 
            if repository:
 
                issue_url = issue_url.replace('{repo}', repository)
 
                repo_name = repository.split(URL_SEP)[-1]
 
                issue_url = issue_url.replace('{repo_name}', repo_name)
 

	
 
                return (
 
                    '%(pref)s<a class="%(cls)s" href="%(url)s">'
 
                    '%(issue-prefix)s%(id-repr)s'
 
                    '</a>'
 
                    ) % {
 
                     'pref': pref,
 
                     'cls': 'issue-tracker-link',
 
                     'url': issue_url,
 
                     'id-repr': issue_id,
 
                     'issue-prefix': ISSUE_PREFIX,
 
                     'serv': ISSUE_SERVER_LNK,
 
                    }
 
            newtext = URL_PAT.sub(url_func, newtext)
 
            log.debug('processed prefix:`%s` => %s' % (pattern_index, newtext))
 
            return (
 
                '%(pref)s<a class="%(cls)s" href="%(url)s">'
 
                '%(issue-prefix)s%(id-repr)s'
 
                '</a>'
 
                ) % {
 
                 'pref': pref,
 
                 'cls': 'issue-tracker-link',
 
                 'url': issue_url,
 
                 'id-repr': issue_id,
 
                 'issue-prefix': ISSUE_PREFIX,
 
                 'serv': ISSUE_SERVER_LNK,
 
                }
 
        newtext = URL_PAT.sub(url_func, newtext)
 
        log.debug('processed prefix:`%s` => %s' % (pattern_index, newtext))
 

	
 
        # if we actually did something above
 
        if link_:
 
            # wrap not links into final link => link_
 
            newtext = linkify_others(newtext, link_)
 
    except Exception:
 
        log.error(traceback.format_exc())
 
        pass
 
    # if we actually did something above
 
    if link_:
 
        # wrap not links into final link => link_
 
        newtext = linkify_others(newtext, link_)
 
    return newtext
 

	
 

	
 
def rst(source):
 
    return literal('<div class="rst-block">%s</div>' %
 
                   MarkupRenderer.rst(source))
 

	
 

	
 
def rst_w_mentions(source):
 
    """
 
    Wrapped rst renderer with @mention highlighting
 

	
 
    :param source:
 
    """
 
    return literal('<div class="rst-block">%s</div>' %
 
                   MarkupRenderer.rst_with_mentions(source))
 

	
 
def short_ref(ref_type, ref_name):
 
    if ref_type == 'rev':
 
        return short_id(ref_name)
 
    return ref_name
 

	
 
def link_to_ref(repo_name, ref_type, ref_name, rev=None):
 
    """
 
    Return full markup for a href to changeset_home for a changeset.
 
    If ref_type is branch it will link to changelog.
 
    ref_name is shortened if ref_type is 'rev'.
 
    if rev is specified show it too, explicitly linking to that revision.
 
    """
 
    txt = short_ref(ref_type, ref_name)
 
    if ref_type == 'branch':
 
        u = url('changelog_home', repo_name=repo_name, branch=ref_name)
 
    else:
 
        u = url('changeset_home', repo_name=repo_name, revision=ref_name)
 
    l = link_to(repo_name + '#' + txt, u)
 
    if rev and ref_type != 'rev':
 
        l = literal('%s (%s)' % (l, link_to(short_id(rev), url('changeset_home', repo_name=repo_name, revision=rev))))
 
    return l
 

	
 
def changeset_status(repo, revision):
 
    return ChangesetStatusModel().get_status(repo, revision)
 

	
 

	
 
def changeset_status_lbl(changeset_status):
 
    return dict(ChangesetStatus.STATUSES).get(changeset_status)
 

	
 

	
 
def get_permission_name(key):
 
    return dict(Permission.PERMS).get(key)
 

	
 

	
 
def journal_filter_help():
 
    return _(textwrap.dedent('''
 
        Example filter terms:
 
            repository:vcs
 
            username:developer
 
            action:*push*
 
            ip:127.0.0.1
 
            date:20120101
 
            date:[20120101100000 TO 20120102]
 

	
 
        Generate wildcards using '*' character:
 
            "repositroy:vcs*" - search everything starting with 'vcs'
 
            "repository:*vcs*" - search for repository containing 'vcs'
 

	
 
        Optional AND / OR operators in queries
 
            "repository:vcs OR repository:test"
 
            "username:test AND repository:test*"
 
    '''))
 

	
 

	
 
def not_mapped_error(repo_name):
 
    flash(_('%s repository is not mapped to db perhaps'
 
            ' it was created or renamed from the filesystem'
 
            ' please run the application again'
 
            ' in order to rescan repositories') % repo_name, category='error')
 

	
 

	
 
def ip_range(ip_addr):
 
    from kallithea.model.db import UserIpMap
 
    s, e = UserIpMap._get_ip_range(ip_addr)
 
    return '%s - %s' % (s, e)
kallithea/lib/utils.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.lib.utils
 
~~~~~~~~~~~~~~~~~~~
 

	
 
Utilities library for Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 18, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
import os
 
import re
 
import logging
 
import datetime
 
import traceback
 
import paste
 
import beaker
 
import tarfile
 
import shutil
 
import decorator
 
import warnings
 
from os.path import abspath
 
from os.path import dirname as dn, join as jn
 

	
 
from paste.script.command import Command, BadCommand
 

	
 
from webhelpers.text import collapse, remove_formatting, strip_tags
 
from beaker.cache import _cache_decorate
 

	
 
from kallithea import BRAND
 

	
 
from kallithea.lib.vcs.utils.hgcompat import ui, config
 
from kallithea.lib.vcs.utils.helpers import get_scm
 
from kallithea.lib.vcs.exceptions import VCSError
 

	
 
from kallithea.lib.caching_query import FromCache
 

	
 
from kallithea.model import meta
 
from kallithea.model.db import Repository, User, Ui, \
 
    UserLog, RepoGroup, Setting, CacheInvalidation, UserGroup
 
from kallithea.model.meta import Session
 
from kallithea.model.repo_group import RepoGroupModel
 
from kallithea.lib.utils2 import safe_str, safe_unicode, get_current_authuser
 
from kallithea.lib.vcs.utils.fakemod import create_module
 

	
 
log = logging.getLogger(__name__)
 

	
 
REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
 

	
 

	
 
def recursive_replace(str_, replace=' '):
 
    """
 
    Recursive replace of given sign to just one instance
 

	
 
    :param str_: given string
 
    :param replace: char to find and replace multiple instances
 

	
 
    Examples::
 
    >>> recursive_replace("Mighty---Mighty-Bo--sstones",'-')
 
    'Mighty-Mighty-Bo-sstones'
 
    """
 

	
 
    if str_.find(replace * 2) == -1:
 
        return str_
 
    else:
 
        str_ = str_.replace(replace * 2, replace)
 
        return recursive_replace(str_, replace)
 

	
 

	
 
def repo_name_slug(value):
 
    """
 
    Return slug of name of repository
 
    This function is called on each creation/modification
 
    of repository to prevent bad names in repo
 
    """
 

	
 
    slug = remove_formatting(value)
 
    slug = strip_tags(slug)
 

	
 
    for c in """`?=[]\;'"<>,/~!@#$%^&*()+{}|: """:
 
        slug = slug.replace(c, '-')
 
    slug = recursive_replace(slug, '-')
 
    slug = collapse(slug, '-')
 
    return slug
 

	
 

	
 
#==============================================================================
 
# PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
 
#==============================================================================
 
def get_repo_slug(request):
 
    _repo = request.environ['pylons.routes_dict'].get('repo_name')
 
    if _repo:
 
        _repo = _repo.rstrip('/')
 
    return _repo
 

	
 

	
 
def get_repo_group_slug(request):
 
    _group = request.environ['pylons.routes_dict'].get('group_name')
 
    if _group:
 
        _group = _group.rstrip('/')
 
    return _group
 

	
 

	
 
def get_user_group_slug(request):
 
    _group = request.environ['pylons.routes_dict'].get('id')
 
    try:
 
        _group = UserGroup.get(_group)
 
        if _group:
 
            _group = _group.users_group_name
 
    except Exception:
 
        log.debug(traceback.format_exc())
 
        #catch all failures here
 
        pass
 

	
 
    return _group
 
    _group = UserGroup.get(_group)
 
    if _group:
 
        return _group.users_group_name
 
    return None
 

	
 

	
 
def _extract_id_from_repo_name(repo_name):
 
    if repo_name.startswith('/'):
 
        repo_name = repo_name.lstrip('/')
 
    by_id_match = re.match(r'^_(\d{1,})', repo_name)
 
    if by_id_match:
 
        return by_id_match.groups()[0]
 

	
 

	
 
def get_repo_by_id(repo_name):
 
    """
 
    Extracts repo_name by id from special urls. Example url is _11/repo_name
 

	
 
    :param repo_name:
 
    :return: repo_name if matched else None
 
    """
 
    try:
 
        _repo_id = _extract_id_from_repo_name(repo_name)
 
        if _repo_id:
 
            from kallithea.model.db import Repository
 
            return Repository.get(_repo_id).repo_name
 
    except Exception:
 
        log.debug('Failed to extract repo_name from URL %s' % (
 
                  traceback.format_exc()))
 
        return
 
    _repo_id = _extract_id_from_repo_name(repo_name)
 
    if _repo_id:
 
        from kallithea.model.db import Repository
 
        repo = Repository.get(_repo_id)
 
        if repo:
 
            # TODO: return repo instead of reponame? or would that be a layering violation?
 
            return repo.repo_name
 
    return None
 

	
 

	
 
def action_logger(user, action, repo, ipaddr='', sa=None, commit=False):
 
    """
 
    Action logger for various actions made by users
 

	
 
    :param user: user that made this action, can be a unique username string or
 
        object containing user_id attribute
 
    :param action: action to log, should be on of predefined unique actions for
 
        easy translations
 
    :param repo: string name of repository or object containing repo_id,
 
        that action was made on
 
    :param ipaddr: optional ip address from what the action was made
 
    :param sa: optional sqlalchemy session
 

	
 
    """
 

	
 
    if not sa:
 
        sa = meta.Session()
 
    # if we don't get explicit IP address try to get one from registered user
 
    # in tmpl context var
 
    if not ipaddr:
 
        ipaddr = getattr(get_current_authuser(), 'ip_addr', '')
 

	
 
    try:
 
        if getattr(user, 'user_id', None):
 
            user_obj = User.get(user.user_id)
 
        elif isinstance(user, basestring):
 
            user_obj = User.get_by_username(user)
 
        else:
 
            raise Exception('You have to provide a user object or a username')
 
    if getattr(user, 'user_id', None):
 
        user_obj = User.get(user.user_id)
 
    elif isinstance(user, basestring):
 
        user_obj = User.get_by_username(user)
 
    else:
 
        raise Exception('You have to provide a user object or a username')
 

	
 
        if getattr(repo, 'repo_id', None):
 
            repo_obj = Repository.get(repo.repo_id)
 
            repo_name = repo_obj.repo_name
 
        elif isinstance(repo, basestring):
 
            repo_name = repo.lstrip('/')
 
            repo_obj = Repository.get_by_repo_name(repo_name)
 
        else:
 
            repo_obj = None
 
            repo_name = ''
 
    if getattr(repo, 'repo_id', None):
 
        repo_obj = Repository.get(repo.repo_id)
 
        repo_name = repo_obj.repo_name
 
    elif isinstance(repo, basestring):
 
        repo_name = repo.lstrip('/')
 
        repo_obj = Repository.get_by_repo_name(repo_name)
 
    else:
 
        repo_obj = None
 
        repo_name = ''
 

	
 
        user_log = UserLog()
 
        user_log.user_id = user_obj.user_id
 
        user_log.username = user_obj.username
 
        user_log.action = safe_unicode(action)
 
    user_log = UserLog()
 
    user_log.user_id = user_obj.user_id
 
    user_log.username = user_obj.username
 
    user_log.action = safe_unicode(action)
 

	
 
        user_log.repository = repo_obj
 
        user_log.repository_name = repo_name
 
    user_log.repository = repo_obj
 
    user_log.repository_name = repo_name
 

	
 
        user_log.action_date = datetime.datetime.now()
 
        user_log.user_ip = ipaddr
 
        sa.add(user_log)
 
    user_log.action_date = datetime.datetime.now()
 
    user_log.user_ip = ipaddr
 
    sa.add(user_log)
 

	
 
        log.info('Logging action:%s on %s by user:%s ip:%s' %
 
                 (action, safe_unicode(repo), user_obj, ipaddr))
 
        if commit:
 
            sa.commit()
 
    except Exception:
 
        log.error(traceback.format_exc())
 
        raise
 
    log.info('Logging action:%s on %s by user:%s ip:%s' %
 
             (action, safe_unicode(repo), user_obj, ipaddr))
 
    if commit:
 
        sa.commit()
 

	
 

	
 
def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
 
    """
 
    Scans given path for repos and return (name,(type,path)) tuple
 

	
 
    :param path: path to scan for repositories
 
    :param recursive: recursive search and return names with subdirs in front
 
    """
 

	
 
    # remove ending slash for better results
 
    path = path.rstrip(os.sep)
 
    log.debug('now scanning in %s location recursive:%s...' % (path, recursive))
 

	
 
    def _get_repos(p):
 
        if not os.access(p, os.R_OK) or not os.access(p, os.X_OK):
 
            log.warning('ignoring repo path without access: %s' % (p,))
 
            return
 
        if not os.access(p, os.W_OK):
 
            log.warning('repo path without write access: %s' % (p,))
 
        for dirpath in os.listdir(p):
 
            if os.path.isfile(os.path.join(p, dirpath)):
 
                continue
 
            cur_path = os.path.join(p, dirpath)
 

	
 
            # skip removed repos
 
            if skip_removed_repos and REMOVED_REPO_PAT.match(dirpath):
 
                continue
 

	
 
            #skip .<somethin> dirs
 
            if dirpath.startswith('.'):
 
                continue
 

	
 
            try:
 
                scm_info = get_scm(cur_path)
 
                yield scm_info[1].split(path, 1)[-1].lstrip(os.sep), scm_info
 
            except VCSError:
 
                if not recursive:
 
                    continue
 
                #check if this dir containts other repos for recursive scan
 
                rec_path = os.path.join(p, dirpath)
 
                if os.path.isdir(rec_path):
 
                    for inner_scm in _get_repos(rec_path):
 
                        yield inner_scm
 

	
 
    return _get_repos(path)
 

	
 

	
 
def is_valid_repo(repo_name, base_path, scm=None):
 
    """
 
    Returns True if given path is a valid repository False otherwise.
 
    If scm param is given also compare if given scm is the same as expected
 
    from scm parameter
 

	
 
    :param repo_name:
 
    :param base_path:
 
    :param scm:
 

	
 
    :return True: if given path is a valid repository
 
    """
 
    full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
 

	
 
    try:
 
        scm_ = get_scm(full_path)
 
        if scm:
 
            return scm_[0] == scm
 
        return True
 
    except VCSError:
 
        return False
 

	
 

	
 
def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
 
    """
 
    Returns True if given path is a repository group False otherwise
 

	
 
    :param repo_name:
 
    :param base_path:
 
    """
 
    full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
 

	
 
    # check if it's not a repo
 
    if is_valid_repo(repo_group_name, base_path):
 
        return False
 

	
 
    try:
 
        # we need to check bare git repos at higher level
 
        # since we might match branches/hooks/info/objects or possible
 
        # other things inside bare git repo
 
        get_scm(os.path.dirname(full_path))
 
        return False
 
    except VCSError:
 
        pass
 

	
 
    # check if it's a valid path
 
    if skip_path_check or os.path.isdir(full_path):
 
        return True
 

	
 
    return False
 

	
 

	
 
def ask_ok(prompt, retries=4, complaint='Yes or no please!'):
 
    while True:
 
        ok = raw_input(prompt)
 
        if ok in ('y', 'ye', 'yes'):
 
            return True
 
        if ok in ('n', 'no', 'nop', 'nope'):
 
            return False
 
        retries = retries - 1
 
        if retries < 0:
 
            raise IOError
 
        print complaint
 

	
 
#propagated from mercurial documentation
 
ui_sections = ['alias', 'auth',
 
                'decode/encode', 'defaults',
 
                'diff', 'email',
 
                'extensions', 'format',
 
                'merge-patterns', 'merge-tools',
 
                'hooks', 'http_proxy',
 
                'smtp', 'patch',
 
                'paths', 'profiling',
 
                'server', 'trusted',
 
                'ui', 'web', ]
 

	
 

	
 
def make_ui(read_from='file', path=None, checkpaths=True, clear_session=True):
 
    """
 
    A function that will read python rc files or database
 
    and make an mercurial ui object from read options
 

	
 
    :param path: path to mercurial config file
 
    :param checkpaths: check the path
 
    :param read_from: read from 'file' or 'db'
 
    """
 

	
 
    baseui = ui.ui()
 

	
 
    # clean the baseui object
 
    baseui._ocfg = config.config()
 
    baseui._ucfg = config.config()
 
    baseui._tcfg = config.config()
 

	
 
    if read_from == 'file':
 
        if not os.path.isfile(path):
 
            log.debug('hgrc file is not present at %s, skipping...' % path)
 
            return False
 
        log.debug('reading hgrc from %s' % path)
 
        cfg = config.config()
 
        cfg.read(path)
 
        for section in ui_sections:
 
            for k, v in cfg.items(section):
 
                log.debug('settings ui from file: [%s] %s=%s' % (section, k, v))
 
                baseui.setconfig(safe_str(section), safe_str(k), safe_str(v))
 

	
 
    elif read_from == 'db':
 
        sa = meta.Session()
 
        ret = sa.query(Ui)\
 
            .options(FromCache("sql_cache_short", "get_hg_ui_settings"))\
 
            .all()
 

	
 
        hg_ui = ret
 
        for ui_ in hg_ui:
 
            if ui_.ui_active:
 
                ui_val = safe_str(ui_.ui_value)
 
                if ui_.ui_section == 'hooks' and BRAND != 'kallithea' and ui_val.startswith('python:' + BRAND + '.lib.hooks.'):
 
                    ui_val = ui_val.replace('python:' + BRAND + '.lib.hooks.', 'python:kallithea.lib.hooks.')
 
                log.debug('settings ui from db: [%s] %s=%s', ui_.ui_section,
 
                          ui_.ui_key, ui_val)
 
                baseui.setconfig(safe_str(ui_.ui_section), safe_str(ui_.ui_key),
 
                                 ui_val)
 
            if ui_.ui_key == 'push_ssl':
 
                # force set push_ssl requirement to False, kallithea
 
                # handles that
 
                baseui.setconfig(safe_str(ui_.ui_section), safe_str(ui_.ui_key),
 
                                 False)
 
        if clear_session:
 
            meta.Session.remove()
 
    return baseui
 

	
 

	
 
def set_app_settings(config):
 
    """
 
    Updates pylons config with new settings from database
 

	
 
    :param config:
 
    """
 
    hgsettings = Setting.get_app_settings()
 

	
 
    for k, v in hgsettings.items():
 
        config[k] = v
 

	
 

	
 
def set_vcs_config(config):
 
    """
 
    Patch VCS config with some Kallithea specific stuff
 

	
 
    :param config: kallithea.CONFIG
 
    """
 
    import kallithea
 
    from kallithea.lib.vcs import conf
 
    from kallithea.lib.utils2 import aslist
 
    conf.settings.BACKENDS = {
 
        'hg': 'kallithea.lib.vcs.backends.hg.MercurialRepository',
 
        'git': 'kallithea.lib.vcs.backends.git.GitRepository',
 
    }
 

	
 
    conf.settings.GIT_EXECUTABLE_PATH = config.get('git_path', 'git')
 
    conf.settings.GIT_REV_FILTER = config.get('git_rev_filter', '--all').strip()
 
    conf.settings.DEFAULT_ENCODINGS = aslist(config.get('default_encoding',
 
                                                        'utf8'), sep=',')
 

	
 

	
 
def map_groups(path):
 
    """
 
    Given a full path to a repository, create all nested groups that this
 
    repo is inside. This function creates parent-child relationships between
 
    groups and creates default perms for all new groups.
 

	
 
    :param paths: full path to repository
 
    """
 
    sa = meta.Session()
 
    groups = path.split(Repository.url_sep())
 
    parent = None
 
    group = None
 

	
 
    # last element is repo in nested groups structure
 
    groups = groups[:-1]
 
    rgm = RepoGroupModel(sa)
 
    owner = User.get_first_admin()
 
    for lvl, group_name in enumerate(groups):
 
        group_name = '/'.join(groups[:lvl] + [group_name])
 
        group = RepoGroup.get_by_group_name(group_name)
 
        desc = '%s group' % group_name
 

	
 
        # skip folders that are now removed repos
 
        if REMOVED_REPO_PAT.match(group_name):
 
            break
 

	
 
        if group is None:
 
            log.debug('creating group level: %s group_name: %s'
 
                      % (lvl, group_name))
 
            group = RepoGroup(group_name, parent)
 
            group.group_description = desc
 
            group.user = owner
 
            sa.add(group)
 
            perm_obj = rgm._create_default_perms(group)
 
            sa.add(perm_obj)
 
            sa.flush()
 

	
 
        parent = group
 
    return group
 

	
 

	
 
def repo2db_mapper(initial_repo_list, remove_obsolete=False,
 
                   install_git_hook=False, user=None):
 
    """
 
    maps all repos given in initial_repo_list, non existing repositories
 
    are created, if remove_obsolete is True it also check for db entries
 
    that are not in initial_repo_list and removes them.
 

	
 
    :param initial_repo_list: list of repositories found by scanning methods
 
    :param remove_obsolete: check for obsolete entries in database
 
    :param install_git_hook: if this is True, also check and install githook
 
        for a repo if missing
 
    """
 
    from kallithea.model.repo import RepoModel
 
    from kallithea.model.scm import ScmModel
 
    sa = meta.Session()
 
    repo_model = RepoModel()
 
    if user is None:
 
        user = User.get_first_admin()
 
    added = []
 

	
 
    ##creation defaults
 
    defs = Setting.get_default_repo_settings(strip_prefix=True)
 
    enable_statistics = defs.get('repo_enable_statistics')
 
    enable_locking = defs.get('repo_enable_locking')
 
    enable_downloads = defs.get('repo_enable_downloads')
 
    private = defs.get('repo_private')
 

	
 
    for name, repo in initial_repo_list.items():
 
        group = map_groups(name)
 
        unicode_name = safe_unicode(name)
 
        db_repo = repo_model.get_by_repo_name(unicode_name)
 
        # found repo that is on filesystem not in Kallithea database
 
        if not db_repo:
 
            log.info('repository %s not found, creating now' % name)
 
            added.append(name)
 
            desc = (repo.description
 
                    if repo.description != 'unknown'
 
                    else '%s repository' % name)
 

	
 
            new_repo = repo_model._create_repo(
 
                repo_name=name,
 
                repo_type=repo.alias,
 
                description=desc,
 
                repo_group=getattr(group, 'group_id', None),
 
                owner=user,
 
                enable_locking=enable_locking,
 
                enable_downloads=enable_downloads,
 
                enable_statistics=enable_statistics,
 
                private=private,
 
                state=Repository.STATE_CREATED
 
            )
 
            sa.commit()
 
            # we added that repo just now, and make sure it has githook
 
            # installed, and updated server info
 
            if new_repo.repo_type == 'git':
 
                git_repo = new_repo.scm_instance
 
                ScmModel().install_git_hook(git_repo)
 
                # update repository server-info
 
                log.debug('Running update server info')
 
                git_repo._update_server_info()
 
            new_repo.update_changeset_cache()
 
        elif install_git_hook:
 
            if db_repo.repo_type == 'git':
 
                ScmModel().install_git_hook(db_repo.scm_instance)
 

	
 
    removed = []
 
    if remove_obsolete:
 
        # remove from database those repositories that are not in the filesystem
 
        for repo in sa.query(Repository).all():
 
            if repo.repo_name not in initial_repo_list.keys():
 
                log.debug("Removing non-existing repository found in db `%s`" %
 
                          repo.repo_name)
 
                try:
 
                    removed.append(repo.repo_name)
 
                    RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
 
                    sa.commit()
 
                except Exception:
 
                    #don't hold further removals on error
 
                    log.error(traceback.format_exc())
 
                    sa.rollback()
 
    return added, removed
 

	
 

	
 
# set cache regions for beaker so celery can utilise it
 
def add_cache(settings):
 
    cache_settings = {'regions': None}
 
    for key in settings.keys():
 
        for prefix in ['beaker.cache.', 'cache.']:
 
            if key.startswith(prefix):
 
                name = key.split(prefix)[1].strip()
 
                cache_settings[name] = settings[key].strip()
 
    if cache_settings['regions']:
 
        for region in cache_settings['regions'].split(','):
 
            region = region.strip()
 
            region_settings = {}
 
            for key, value in cache_settings.items():
 
                if key.startswith(region):
 
                    region_settings[key.split('.')[1]] = value
 
            region_settings['expire'] = int(region_settings.get('expire',
 
                                                                60))
 
            region_settings.setdefault('lock_dir',
 
                                       cache_settings.get('lock_dir'))
 
            region_settings.setdefault('data_dir',
 
                                       cache_settings.get('data_dir'))
 

	
 
            if 'type' not in region_settings:
 
                region_settings['type'] = cache_settings.get('type',
 
                                                             'memory')
 
            beaker.cache.cache_regions[region] = region_settings
 

	
 

	
 
def load_rcextensions(root_path):
 
    import kallithea
 
    from kallithea.config import conf
 

	
 
    path = os.path.join(root_path, 'rcextensions', '__init__.py')
 
    if os.path.isfile(path):
 
        rcext = create_module('rc', path)
 
        EXT = kallithea.EXTENSIONS = rcext
 
        log.debug('Found rcextensions now loading %s...' % rcext)
 

	
 
        # Additional mappings that are not present in the pygments lexers
 
        conf.LANGUAGES_EXTENSIONS_MAP.update(getattr(EXT, 'EXTRA_MAPPINGS', {}))
 

	
 
        #OVERRIDE OUR EXTENSIONS FROM RC-EXTENSIONS (if present)
 

	
 
        if getattr(EXT, 'INDEX_EXTENSIONS', []):
 
            log.debug('settings custom INDEX_EXTENSIONS')
 
            conf.INDEX_EXTENSIONS = getattr(EXT, 'INDEX_EXTENSIONS', [])
 

	
 
        #ADDITIONAL MAPPINGS
 
        log.debug('adding extra into INDEX_EXTENSIONS')
 
        conf.INDEX_EXTENSIONS.extend(getattr(EXT, 'EXTRA_INDEX_EXTENSIONS', []))
 

	
 
        # auto check if the module is not missing any data, set to default if is
 
        # this will help autoupdate new feature of rcext module
 
        #from kallithea.config import rcextensions
 
        #for k in dir(rcextensions):
 
        #    if not k.startswith('_') and not hasattr(EXT, k):
 
        #        setattr(EXT, k, getattr(rcextensions, k))
 

	
 

	
 
def get_custom_lexer(extension):
 
    """
 
    returns a custom lexer if it's defined in rcextensions module, or None
 
    if there's no custom lexer defined
 
    """
 
    import kallithea
 
    from pygments import lexers
 
    #check if we didn't define this extension as other lexer
 
    if kallithea.EXTENSIONS and extension in kallithea.EXTENSIONS.EXTRA_LEXERS:
 
        _lexer_name = kallithea.EXTENSIONS.EXTRA_LEXERS[extension]
 
        return lexers.get_lexer_by_name(_lexer_name)
 

	
 

	
 
#==============================================================================
 
# TEST FUNCTIONS AND CREATORS
 
#==============================================================================
 
def create_test_index(repo_location, config, full_index):
 
    """
 
    Makes default test index
 

	
 
    :param config: test config
 
    :param full_index:
 
    """
 

	
 
    from kallithea.lib.indexers.daemon import WhooshIndexingDaemon
 
    from kallithea.lib.pidlock import DaemonLock, LockHeld
 

	
 
    repo_location = repo_location
 

	
 
    index_location = os.path.join(config['app_conf']['index_dir'])
 
    if not os.path.exists(index_location):
 
        os.makedirs(index_location)
 

	
 
    try:
 
        l = DaemonLock(file_=jn(dn(index_location), 'make_index.lock'))
 
        WhooshIndexingDaemon(index_location=index_location,
 
                             repo_location=repo_location)\
 
            .run(full_index=full_index)
 
        l.release()
 
    except LockHeld:
 
        pass
 

	
 

	
 
def create_test_env(repos_test_path, config):
 
    """
 
    Makes a fresh database and
 
    install test repository into tmp dir
 
    """
 
    from kallithea.lib.db_manage import DbManage
 
    from kallithea.tests import HG_REPO, GIT_REPO, TESTS_TMP_PATH
 

	
 
    # PART ONE create db
 
    dbconf = config['sqlalchemy.db1.url']
 
    log.debug('making test db %s' % dbconf)
 

	
 
    # create test dir if it doesn't exist
 
    if not os.path.isdir(repos_test_path):
 
        log.debug('Creating testdir %s' % repos_test_path)
 
        os.makedirs(repos_test_path)
 

	
 
    dbmanage = DbManage(log_sql=True, dbconf=dbconf, root=config['here'],
 
                        tests=True)
 
    dbmanage.create_tables(override=True)
 
    # for tests dynamically set new root paths based on generated content
 
    dbmanage.create_settings(dbmanage.config_prompt(repos_test_path))
 
    dbmanage.create_default_user()
 
    dbmanage.admin_prompt()
 
    dbmanage.create_permissions()
 
    dbmanage.populate_default_permissions()
 
    Session().commit()
 
    # PART TWO make test repo
 
    log.debug('making test vcs repositories')
 

	
 
    idx_path = config['app_conf']['index_dir']
 
    data_path = config['app_conf']['cache_dir']
 

	
 
    #clean index and data
 
    if idx_path and os.path.exists(idx_path):
 
        log.debug('remove %s' % idx_path)
 
        shutil.rmtree(idx_path)
 

	
 
    if data_path and os.path.exists(data_path):
 
        log.debug('remove %s' % data_path)
 
        shutil.rmtree(data_path)
 

	
 
    #CREATE DEFAULT TEST REPOS
 
    cur_dir = dn(dn(abspath(__file__)))
 
    tar = tarfile.open(jn(cur_dir, 'tests', 'fixtures', "vcs_test_hg.tar.gz"))
 
    tar.extractall(jn(TESTS_TMP_PATH, HG_REPO))
 
    tar.close()
 

	
 
    cur_dir = dn(dn(abspath(__file__)))
 
    tar = tarfile.open(jn(cur_dir, 'tests', 'fixtures', "vcs_test_git.tar.gz"))
 
    tar.extractall(jn(TESTS_TMP_PATH, GIT_REPO))
 
    tar.close()
 

	
 
    #LOAD VCS test stuff
 
    from kallithea.tests.vcs import setup_package
 
    setup_package()
 

	
 

	
 
#==============================================================================
 
# PASTER COMMANDS
 
#==============================================================================
 
class BasePasterCommand(Command):
 
    """
 
    Abstract Base Class for paster commands.
 

	
 
    The celery commands are somewhat aggressive about loading
 
    celery.conf, and since our module sets the `CELERY_LOADER`
 
    environment variable to our loader, we have to bootstrap a bit and
 
    make sure we've had a chance to load the pylons config off of the
 
    command line, otherwise everything fails.
 
    """
 
    min_args = 1
 
    min_args_error = "Please provide a paster config file as an argument."
 
    takes_config_file = 1
 
    requires_config_file = True
 

	
 
    def notify_msg(self, msg, log=False):
 
        """Make a notification to user, additionally if logger is passed
 
        it logs this action using given logger
 

	
 
        :param msg: message that will be printed to user
 
        :param log: logging instance, to use to additionally log this message
 

	
 
        """
 
        if log and isinstance(log, logging):
 
            log(msg)
 

	
 
    def run(self, args):
 
        """
 
        Overrides Command.run
 

	
 
        Checks for a config file argument and loads it.
 
        """
 
        if len(args) < self.min_args:
 
            raise BadCommand(
 
                self.min_args_error % {'min_args': self.min_args,
 
                                       'actual_args': len(args)})
 

	
 
        # Decrement because we're going to lob off the first argument.
 
        # @@ This is hacky
 
        self.min_args -= 1
 
        self.bootstrap_config(args[0])
 
        self.update_parser()
 
        return super(BasePasterCommand, self).run(args[1:])
 

	
 
    def update_parser(self):
 
        """
 
        Abstract method.  Allows for the class's parser to be updated
 
        before the superclass's `run` method is called.  Necessary to
 
        allow options/arguments to be passed through to the underlying
 
        celery command.
 
        """
 
        raise NotImplementedError("Abstract Method.")
 

	
 
    def bootstrap_config(self, conf):
 
        """
 
        Loads the pylons configuration.
 
        """
 
        from pylons import config as pylonsconfig
 

	
 
        self.path_to_ini_file = os.path.realpath(conf)
 
        conf = paste.deploy.appconfig('config:' + self.path_to_ini_file)
 
        pylonsconfig.init_app(conf.global_conf, conf.local_conf)
 

	
 
    def _init_session(self):
 
        """
 
        Inits SqlAlchemy Session
 
        """
 
        logging.config.fileConfig(self.path_to_ini_file)
 
        from pylons import config
 
        from kallithea.model import init_model
 
        from kallithea.lib.utils2 import engine_from_config
 

	
 
        #get to remove repos !!
 
        add_cache(config)
 
        engine = engine_from_config(config, 'sqlalchemy.db1.')
 
        init_model(engine)
 

	
 

	
 
def check_git_version():
 
    """
 
    Checks what version of git is installed in system, and issues a warning
 
    if it's too old for Kallithea to properly work.
 
    """
 
    from kallithea import BACKENDS
 
    from kallithea.lib.vcs.backends.git.repository import GitRepository
 
    from kallithea.lib.vcs.conf import settings
 
    from distutils.version import StrictVersion
 

	
 
    if 'git' not in BACKENDS:
 
        return None
 

	
 
    stdout, stderr = GitRepository._run_git_command('--version', _bare=True,
 
                                                    _safe=True)
 

	
 
    ver = (stdout.split(' ')[-1] or '').strip() or '0.0.0'
 
    if len(ver.split('.')) > 3:
 
        #StrictVersion needs to be only 3 element type
 
        ver = '.'.join(ver.split('.')[:3])
 
    try:
 
        _ver = StrictVersion(ver)
 
    except Exception:
 
    except ValueError:
 
        _ver = StrictVersion('0.0.0')
 
        stderr = traceback.format_exc()
 

	
 
    req_ver = '1.7.4'
 
    to_old_git = False
 
    if  _ver < StrictVersion(req_ver):
 
        to_old_git = True
 

	
 
    if 'git' in BACKENDS:
 
        log.debug('Git executable: "%s" version detected: %s'
 
                  % (settings.GIT_EXECUTABLE_PATH, stdout))
 
        if stderr:
 
            log.warning('Unable to detect git version, org error was: %r' % stderr)
 
        elif to_old_git:
 
            log.warning('Kallithea detected git version %s, which is too old '
 
                        'for the system to function properly. Make sure '
 
                        'its version is at least %s' % (ver, req_ver))
 
    return _ver
 

	
 

	
 
@decorator.decorator
 
def jsonify(func, *args, **kwargs):
 
    """Action decorator that formats output for JSON
 

	
 
    Given a function that will return content, this decorator will turn
 
    the result into JSON, with a content-type of 'application/json' and
 
    output it.
 

	
 
    """
 
    from pylons.decorators.util import get_pylons
 
    from kallithea.lib.compat import json
 
    pylons = get_pylons(args)
 
    pylons.response.headers['Content-Type'] = 'application/json; charset=utf-8'
 
    data = func(*args, **kwargs)
 
    if isinstance(data, (list, tuple)):
 
        msg = "JSON responses with Array envelopes are susceptible to " \
 
              "cross-site data leak attacks, see " \
 
              "http://wiki.pylonshq.com/display/pylonsfaq/Warnings"
 
        warnings.warn(msg, Warning, 2)
 
        log.warning(msg)
 
    log.debug("Returning JSON wrapped action output")
 
    return json.dumps(data, encoding='utf-8')
 

	
 

	
 
def conditional_cache(region, prefix, condition, func):
 
    """
 

	
 
    Conditional caching function use like::
 
        def _c(arg):
 
            #heavy computation function
 
            return data
 

	
 
        # denpending from condition the compute is wrapped in cache or not
 
        compute = conditional_cache('short_term', 'cache_desc', codnition=True, func=func)
 
        return compute(arg)
 

	
 
    :param region: name of cache region
 
    :param prefix: cache region prefix
 
    :param condition: condition for cache to be triggered, and return data cached
 
    :param func: wrapped heavy function to compute
 

	
 
    """
 
    wrapped = func
 
    if condition:
 
        log.debug('conditional_cache: True, wrapping call of '
 
                  'func: %s into %s region cache' % (region, func))
 
        wrapped = _cache_decorate((prefix,), None, None, region)(func)
 

	
 
    return wrapped
kallithea/lib/utils2.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.lib.utils
 
~~~~~~~~~~~~~~~~~~~
 

	
 
Some simple helper functions
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Jan 5, 2011
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 

	
 
import os
 
import re
 
import sys
 
import time
 
import uuid
 
import datetime
 
import traceback
 
import webob
 
import urllib
 
import urlobject
 

	
 
from pylons.i18n.translation import _, ungettext
 
from kallithea.lib.vcs.utils.lazy import LazyProperty
 
from kallithea.lib.compat import json
 

	
 

	
 
def __get_lem():
 
    """
 
    Get language extension map based on what's inside pygments lexers
 
    """
 
    from pygments import lexers
 
    from string import lower
 
    from collections import defaultdict
 

	
 
    d = defaultdict(lambda: [])
 

	
 
    def __clean(s):
 
        s = s.lstrip('*')
 
        s = s.lstrip('.')
 

	
 
        if s.find('[') != -1:
 
            exts = []
 
            start, stop = s.find('['), s.find(']')
 

	
 
            for suffix in s[start + 1:stop]:
 
                exts.append(s[:s.find('[')] + suffix)
 
            return map(lower, exts)
 
        else:
 
            return map(lower, [s])
 

	
 
    for lx, t in sorted(lexers.LEXERS.items()):
 
        m = map(__clean, t[-2])
 
        if m:
 
            m = reduce(lambda x, y: x + y, m)
 
            for ext in m:
 
                desc = lx.replace('Lexer', '')
 
                d[ext].append(desc)
 

	
 
    return dict(d)
 

	
 

	
 
def str2bool(_str):
 
    """
 
    returs True/False value from given string, it tries to translate the
 
    string into boolean
 

	
 
    :param _str: string value to translate into boolean
 
    :rtype: boolean
 
    :returns: boolean from given string
 
    """
 
    if _str is None:
 
        return False
 
    if _str in (True, False):
 
        return _str
 
    _str = str(_str).strip().lower()
 
    return _str in ('t', 'true', 'y', 'yes', 'on', '1')
 

	
 

	
 
def aslist(obj, sep=None, strip=True):
 
    """
 
    Returns given string separated by sep as list
 

	
 
    :param obj:
 
    :param sep:
 
    :param strip:
 
    """
 
    if isinstance(obj, (basestring)):
 
        lst = obj.split(sep)
 
        if strip:
 
            lst = [v.strip() for v in lst]
 
        return lst
 
    elif isinstance(obj, (list, tuple)):
 
        return obj
 
    elif obj is None:
 
        return []
 
    else:
 
        return [obj]
 

	
 

	
 
def convert_line_endings(line, mode):
 
    """
 
    Converts a given line  "line end" accordingly to given mode
 

	
 
    Available modes are::
 
        0 - Unix
 
        1 - Mac
 
        2 - DOS
 

	
 
    :param line: given line to convert
 
    :param mode: mode to convert to
 
    :rtype: str
 
    :return: converted line according to mode
 
    """
 
    from string import replace
 

	
 
    if mode == 0:
 
            line = replace(line, '\r\n', '\n')
 
            line = replace(line, '\r', '\n')
 
    elif mode == 1:
 
            line = replace(line, '\r\n', '\r')
 
            line = replace(line, '\n', '\r')
 
    elif mode == 2:
 
            line = re.sub("\r(?!\n)|(?<!\r)\n", "\r\n", line)
 
    return line
 

	
 

	
 
def detect_mode(line, default):
 
    """
 
    Detects line break for given line, if line break couldn't be found
 
    given default value is returned
 

	
 
    :param line: str line
 
    :param default: default
 
    :rtype: int
 
    :return: value of line end on of 0 - Unix, 1 - Mac, 2 - DOS
 
    """
 
    if line.endswith('\r\n'):
 
        return 2
 
    elif line.endswith('\n'):
 
        return 0
 
    elif line.endswith('\r'):
 
        return 1
 
    else:
 
        return default
 

	
 

	
 
def generate_api_key(username, salt=None):
 
    """
 
    Generates unique API key for given username, if salt is not given
 
    it'll be generated from some random string
 

	
 
    :param username: username as string
 
    :param salt: salt to hash generate KEY
 
    :rtype: str
 
    :returns: sha1 hash from username+salt
 
    """
 
    from tempfile import _RandomNameSequence
 
    import hashlib
 

	
 
    if salt is None:
 
        salt = _RandomNameSequence().next()
 

	
 
    return hashlib.sha1(username + salt).hexdigest()
 

	
 

	
 
def safe_int(val, default=None):
 
    """
 
    Returns int() of val if val is not convertable to int use default
 
    instead
 

	
 
    :param val:
 
    :param default:
 
    """
 

	
 
    try:
 
        val = int(val)
 
    except (ValueError, TypeError):
 
        val = default
 

	
 
    return val
 

	
 

	
 
def safe_unicode(str_, from_encoding=None):
 
    """
 
    safe unicode function. Does few trick to turn str_ into unicode
 

	
 
    In case of UnicodeDecode error we try to return it with encoding detected
 
    by chardet library if it fails fallback to unicode with errors replaced
 

	
 
    :param str_: string to decode
 
    :rtype: unicode
 
    :returns: unicode object
 
    """
 
    if isinstance(str_, unicode):
 
        return str_
 

	
 
    if not from_encoding:
 
        import kallithea
 
        DEFAULT_ENCODINGS = aslist(kallithea.CONFIG.get('default_encoding',
 
                                                        'utf8'), sep=',')
 
        from_encoding = DEFAULT_ENCODINGS
 

	
 
    if not isinstance(from_encoding, (list, tuple)):
 
        from_encoding = [from_encoding]
 

	
 
    try:
 
        return unicode(str_)
 
    except UnicodeDecodeError:
 
        pass
 

	
 
    for enc in from_encoding:
 
        try:
 
            return unicode(str_, enc)
 
        except UnicodeDecodeError:
 
            pass
 

	
 
    try:
 
        import chardet
 
        encoding = chardet.detect(str_)['encoding']
 
        if encoding is None:
 
            raise Exception()
 
        return str_.decode(encoding)
 
    except (ImportError, UnicodeDecodeError, Exception):
 
        return unicode(str_, from_encoding[0], 'replace')
 

	
 

	
 
def safe_str(unicode_, to_encoding=None):
 
    """
 
    safe str function. Does few trick to turn unicode_ into string
 

	
 
    In case of UnicodeEncodeError we try to return it with encoding detected
 
    by chardet library if it fails fallback to string with errors replaced
 

	
 
    :param unicode_: unicode to encode
 
    :rtype: str
 
    :returns: str object
 
    """
 

	
 
    # if it's not basestr cast to str
 
    if not isinstance(unicode_, basestring):
 
        return str(unicode_)
 

	
 
    if isinstance(unicode_, str):
 
        return unicode_
 

	
 
    if not to_encoding:
 
        import kallithea
 
        DEFAULT_ENCODINGS = aslist(kallithea.CONFIG.get('default_encoding',
 
                                                        'utf8'), sep=',')
 
        to_encoding = DEFAULT_ENCODINGS
 

	
 
    if not isinstance(to_encoding, (list, tuple)):
 
        to_encoding = [to_encoding]
 

	
 
    for enc in to_encoding:
 
        try:
 
            return unicode_.encode(enc)
 
        except UnicodeEncodeError:
 
            pass
 

	
 
    try:
 
        import chardet
 
        encoding = chardet.detect(unicode_)['encoding']
 
        if encoding is None:
 
            raise UnicodeEncodeError()
 

	
 
        return unicode_.encode(encoding)
 
    except (ImportError, UnicodeEncodeError):
 
        return unicode_.encode(to_encoding[0], 'replace')
 

	
 

	
 
def remove_suffix(s, suffix):
 
    if s.endswith(suffix):
 
        s = s[:-1 * len(suffix)]
 
    return s
 

	
 

	
 
def remove_prefix(s, prefix):
 
    if s.startswith(prefix):
 
        s = s[len(prefix):]
 
    return s
 

	
 

	
 
def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
 
    """
 
    Custom engine_from_config functions that makes sure we use NullPool for
 
    file based sqlite databases. This prevents errors on sqlite. This only
 
    applies to sqlalchemy versions < 0.7.0
 

	
 
    """
 
    import sqlalchemy
 
    from sqlalchemy import engine_from_config as efc
 
    import logging
 

	
 
    if int(sqlalchemy.__version__.split('.')[1]) < 7:
 

	
 
        # This solution should work for sqlalchemy < 0.7.0, and should use
 
        # proxy=TimerProxy() for execution time profiling
 

	
 
        from sqlalchemy.pool import NullPool
 
        url = configuration[prefix + 'url']
 

	
 
        if url.startswith('sqlite'):
 
            kwargs.update({'poolclass': NullPool})
 
        return efc(configuration, prefix, **kwargs)
 
    else:
 
        import time
 
        from sqlalchemy import event
 

	
 
        log = logging.getLogger('sqlalchemy.engine')
 
        BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = xrange(30, 38)
 
        engine = efc(configuration, prefix, **kwargs)
 

	
 
        def color_sql(sql):
 
            COLOR_SEQ = "\033[1;%dm"
 
            COLOR_SQL = YELLOW
 
            normal = '\x1b[0m'
 
            return ''.join([COLOR_SEQ % COLOR_SQL, sql, normal])
 

	
 
        if configuration['debug']:
 
            #attach events only for debug configuration
 

	
 
            def before_cursor_execute(conn, cursor, statement,
 
                                    parameters, context, executemany):
 
                context._query_start_time = time.time()
 
                log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
 

	
 
            def after_cursor_execute(conn, cursor, statement,
 
                                    parameters, context, executemany):
 
                total = time.time() - context._query_start_time
 
                log.info(color_sql("<<<<< TOTAL TIME: %f <<<<<" % total))
 

	
 
            event.listen(engine, "before_cursor_execute",
 
                         before_cursor_execute)
 
            event.listen(engine, "after_cursor_execute",
 
                         after_cursor_execute)
 

	
 
    return engine
 

	
 

	
 
def age(prevdate, show_short_version=False, now=None):
 
    """
 
    turns a datetime into an age string.
 
    If show_short_version is True, then it will generate a not so accurate but shorter string,
 
    example: 2days ago, instead of 2 days and 23 hours ago.
 

	
 
    :param prevdate: datetime object
 
    :param show_short_version: if it should aproximate the date and return a shorter string
 
    :rtype: unicode
 
    :returns: unicode words describing age
 
    """
 
    now = now or datetime.datetime.now()
 
    order = ['year', 'month', 'day', 'hour', 'minute', 'second']
 
    deltas = {}
 
    future = False
 

	
 
    if prevdate > now:
 
        now, prevdate = prevdate, now
 
        future = True
 
    if future:
 
        prevdate = prevdate.replace(microsecond=0)
 
    # Get date parts deltas
 
    from dateutil import relativedelta
 
    for part in order:
 
        d = relativedelta.relativedelta(now, prevdate)
 
        deltas[part] = getattr(d, part + 's')
 

	
 
    # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
 
    # not 1 hour, -59 minutes and -59 seconds)
 
    for num, length in [(5, 60), (4, 60), (3, 24)]:  # seconds, minutes, hours
 
        part = order[num]
 
        carry_part = order[num - 1]
 

	
 
        if deltas[part] < 0:
 
            deltas[part] += length
 
            deltas[carry_part] -= 1
 

	
 
    # Same thing for days except that the increment depends on the (variable)
 
    # number of days in the month
 
    month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
 
    if deltas['day'] < 0:
 
        if prevdate.month == 2 and (prevdate.year % 4 == 0 and
 
            (prevdate.year % 100 != 0 or prevdate.year % 400 == 0)):
 
            deltas['day'] += 29
 
        else:
 
            deltas['day'] += month_lengths[prevdate.month - 1]
 

	
 
        deltas['month'] -= 1
 

	
 
    if deltas['month'] < 0:
 
        deltas['month'] += 12
 
        deltas['year'] -= 1
 

	
 
    # Format the result
 
    fmt_funcs = {
 
        'year': lambda d: ungettext(u'%d year', '%d years', d) % d,
 
        'month': lambda d: ungettext(u'%d month', '%d months', d) % d,
 
        'day': lambda d: ungettext(u'%d day', '%d days', d) % d,
 
        'hour': lambda d: ungettext(u'%d hour', '%d hours', d) % d,
 
        'minute': lambda d: ungettext(u'%d minute', '%d minutes', d) % d,
 
        'second': lambda d: ungettext(u'%d second', '%d seconds', d) % d,
 
    }
 

	
 
    for i, part in enumerate(order):
 
        value = deltas[part]
 
        if value == 0:
 
            continue
 

	
 
        if i < 5:
 
            sub_part = order[i + 1]
 
            sub_value = deltas[sub_part]
 
        else:
 
            sub_value = 0
 

	
 
        if sub_value == 0 or show_short_version:
 
            if future:
 
                return _(u'in %s') % fmt_funcs[part](value)
 
            else:
 
                return _(u'%s ago') % fmt_funcs[part](value)
 
        if future:
 
            return _(u'in %s and %s') % (fmt_funcs[part](value),
 
                fmt_funcs[sub_part](sub_value))
 
        else:
 
            return _(u'%s and %s ago') % (fmt_funcs[part](value),
 
                fmt_funcs[sub_part](sub_value))
 

	
 
    return _(u'just now')
 

	
 

	
 
def uri_filter(uri):
 
    """
 
    Removes user:password from given url string
 

	
 
    :param uri:
 
    :rtype: unicode
 
    :returns: filtered list of strings
 
    """
 
    if not uri:
 
        return ''
 

	
 
    proto = ''
 

	
 
    for pat in ('https://', 'http://'):
 
        if uri.startswith(pat):
 
            uri = uri[len(pat):]
 
            proto = pat
 
            break
 

	
 
    # remove passwords and username
 
    uri = uri[uri.find('@') + 1:]
 

	
 
    # get the port
 
    cred_pos = uri.find(':')
 
    if cred_pos == -1:
 
        host, port = uri, None
 
    else:
 
        host, port = uri[:cred_pos], uri[cred_pos + 1:]
 

	
 
    return filter(None, [proto, host, port])
 

	
 

	
 
def credentials_filter(uri):
 
    """
 
    Returns a url with removed credentials
 

	
 
    :param uri:
 
    """
 

	
 
    uri = uri_filter(uri)
 
    #check if we have port
 
    if len(uri) > 2 and uri[2]:
 
        uri[2] = ':' + uri[2]
 

	
 
    return ''.join(uri)
 

	
 

	
 
def get_clone_url(uri_tmpl, qualifed_home_url, repo_name, repo_id, **override):
 
    parsed_url = urlobject.URLObject(qualifed_home_url)
 
    decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))
 
    args = {
 
        'scheme': parsed_url.scheme,
 
        'user': '',
 
        'netloc': parsed_url.netloc+decoded_path,  # path if we use proxy-prefix
 
        'prefix': decoded_path,
 
        'repo': repo_name,
 
        'repoid': str(repo_id)
 
    }
 
    args.update(override)
 
    args['user'] = urllib.quote(safe_str(args['user']))
 

	
 
    for k, v in args.items():
 
        uri_tmpl = uri_tmpl.replace('{%s}' % k, v)
 

	
 
    # remove leading @ sign if it's present. Case of empty user
 
    url_obj = urlobject.URLObject(uri_tmpl)
 
    url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))
 

	
 
    return safe_unicode(url)
 

	
 

	
 
def get_changeset_safe(repo, rev):
 
    """
 
    Safe version of get_changeset if this changeset doesn't exists for a
 
    repo it returns a Dummy one instead
 

	
 
    :param repo:
 
    :param rev:
 
    """
 
    from kallithea.lib.vcs.backends.base import BaseRepository
 
    from kallithea.lib.vcs.exceptions import RepositoryError
 
    from kallithea.lib.vcs.backends.base import EmptyChangeset
 
    if not isinstance(repo, BaseRepository):
 
        raise Exception('You must pass an Repository '
 
                        'object as first argument got %s', type(repo))
 

	
 
    try:
 
        cs = repo.get_changeset(rev)
 
    except (RepositoryError, LookupError):
 
        cs = EmptyChangeset(requested_revision=rev)
 
    return cs
 

	
 

	
 
def datetime_to_time(dt):
 
    if dt:
 
        return time.mktime(dt.timetuple())
 

	
 

	
 
def time_to_datetime(tm):
 
    if tm:
 
        if isinstance(tm, basestring):
 
            try:
 
                tm = float(tm)
 
            except ValueError:
 
                return
 
        return datetime.datetime.fromtimestamp(tm)
 

	
 
# Must match regexp in kallithea/public/js/base.js MentionsAutoComplete()
 
# Eat char before @ - it must not look like we are in an email addresses.
 
# Matching is gready so we don't have to look beyond the end.
 
MENTIONS_REGEX = re.compile(r'(?:^|[^a-zA-Z0-9])@([a-zA-Z0-9][-_.a-zA-Z0-9]*[a-zA-Z0-9])')
 

	
 
def extract_mentioned_users(s):
 
    r"""
 
    Returns unique usernames from given string s that have @mention
 

	
 
    :param s: string to get mentions
 

	
 
    >>> extract_mentioned_users('@1-2.a_X,@1234 not@not @ddd@not @n @ee @ff @gg, @gg;@hh @n\n@zz,')
 
    ['1-2.a_X', '1234', 'ddd', 'ee', 'ff', 'gg', 'hh', 'zz']
 
    """
 
    usrs = set()
 
    for username in MENTIONS_REGEX.findall(s):
 
        usrs.add(username)
 

	
 
    return sorted(list(usrs), key=lambda k: k.lower())
 

	
 

	
 
class AttributeDict(dict):
 
    def __getattr__(self, attr):
 
        return self.get(attr, None)
 
    __setattr__ = dict.__setitem__
 
    __delattr__ = dict.__delitem__
 

	
 

	
 
def fix_PATH(os_=None):
 
    """
 
    Get current active python path, and append it to PATH variable to fix issues
 
    of subprocess calls and different python versions
 
    """
 
    if os_ is None:
 
        import os
 
    else:
 
        os = os_
 

	
 
    cur_path = os.path.split(sys.executable)[0]
 
    if not os.environ['PATH'].startswith(cur_path):
 
        os.environ['PATH'] = '%s:%s' % (cur_path, os.environ['PATH'])
 

	
 

	
 
def obfuscate_url_pw(engine):
 
    _url = engine or ''
 
    from sqlalchemy.engine import url as sa_url
 
    from sqlalchemy.exc import ArgumentError
 
    try:
 
        _url = sa_url.make_url(engine)
 
        if _url.password:
 
            _url.password = 'XXXXX'
 
    except Exception:
 
        pass
 
        _url = sa_url.make_url(engine or '')
 
    except ArgumentError:
 
        return engine
 
    if _url.password:
 
        _url.password = 'XXXXX'
 
    return str(_url)
 

	
 

	
 
def get_server_url(environ):
 
    req = webob.Request(environ)
 
    return req.host_url + req.script_name
 

	
 

	
 
def _extract_extras(env=None):
 
    """
 
    Extracts the rc extras data from os.environ, and wraps it into named
 
    AttributeDict object
 
    """
 
    if not env:
 
        env = os.environ
 

	
 
    try:
 
        rc_extras = json.loads(env['KALLITHEA_EXTRAS'])
 
    except Exception:
 
        print os.environ
 
        print >> sys.stderr, traceback.format_exc()
 
    except KeyError:
 
        rc_extras = {}
 

	
 
    try:
 
        for k in ['username', 'repository', 'locked_by', 'scm', 'make_lock',
 
                  'action', 'ip']:
 
            rc_extras[k]
 
    except KeyError, e:
 
        raise Exception('Missing key %s in os.environ %s' % (e, rc_extras))
 

	
 
    return AttributeDict(rc_extras)
 

	
 

	
 
def _set_extras(extras):
 
    # RC_SCM_DATA can probably be removed in the future, but for compatibilty now...
 
    os.environ['KALLITHEA_EXTRAS'] = os.environ['RC_SCM_DATA'] = json.dumps(extras)
 

	
 

	
 
def unique_id(hexlen=32):
 
    alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
 
    return suuid(truncate_to=hexlen, alphabet=alphabet)
 

	
 

	
 
def suuid(url=None, truncate_to=22, alphabet=None):
 
    """
 
    Generate and return a short URL safe UUID.
 

	
 
    If the url parameter is provided, set the namespace to the provided
 
    URL and generate a UUID.
 

	
 
    :param url to get the uuid for
 
    :truncate_to: truncate the basic 22 UUID to shorter version
 

	
 
    The IDs won't be universally unique any longer, but the probability of
 
    a collision will still be very low.
 
    """
 
    # Define our alphabet.
 
    _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"
 

	
 
    # If no URL is given, generate a random UUID.
 
    if url is None:
 
        unique_id = uuid.uuid4().int
 
    else:
 
        unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int
 

	
 
    alphabet_length = len(_ALPHABET)
 
    output = []
 
    while unique_id > 0:
 
        digit = unique_id % alphabet_length
 
        output.append(_ALPHABET[digit])
 
        unique_id = int(unique_id / alphabet_length)
 
    return "".join(output)[:truncate_to]
 

	
 

	
 
def get_current_authuser():
 
    """
 
    Gets kallithea user from threadlocal tmpl_context variable if it's
 
    defined, else returns None.
 
    """
 
    from pylons import tmpl_context
 
    if hasattr(tmpl_context, 'authuser'):
 
        return tmpl_context.authuser
 

	
 
    return None
 

	
 

	
 
class OptionalAttr(object):
 
    """
 
    Special Optional Option that defines other attribute. Example::
 

	
 
        def test(apiuser, userid=Optional(OAttr('apiuser')):
 
            user = Optional.extract(userid)
 
            # calls
 

	
 
    """
 

	
 
    def __init__(self, attr_name):
 
        self.attr_name = attr_name
 

	
 
    def __repr__(self):
 
        return '<OptionalAttr:%s>' % self.attr_name
 

	
 
    def __call__(self):
 
        return self
 

	
 
#alias
 
OAttr = OptionalAttr
 

	
 

	
 
class Optional(object):
 
    """
 
    Defines an optional parameter::
 

	
 
        param = param.getval() if isinstance(param, Optional) else param
 
        param = param() if isinstance(param, Optional) else param
 

	
 
    is equivalent of::
 

	
 
        param = Optional.extract(param)
 

	
 
    """
 

	
 
    def __init__(self, type_):
 
        self.type_ = type_
 

	
 
    def __repr__(self):
 
        return '<Optional:%s>' % self.type_.__repr__()
 

	
 
    def __call__(self):
 
        return self.getval()
 

	
 
    def getval(self):
 
        """
 
        returns value from this Optional instance
 
        """
 
        if isinstance(self.type_, OAttr):
 
            # use params name
 
            return self.type_.attr_name
 
        return self.type_
 

	
 
    @classmethod
 
    def extract(cls, val):
 
        """
 
        Extracts value from Optional() instance
 

	
 
        :param val:
 
        :return: original value if it's not Optional instance else
 
            value of instance
 
        """
 
        if isinstance(val, cls):
 
            return val.getval()
 
        return val
 

	
 
def urlreadable(s, _cleanstringsub=re.compile('[^-a-zA-Z0-9./]+').sub):
 
    return _cleanstringsub('_', safe_str(s)).rstrip('_')
kallithea/lib/vcs/backends/git/repository.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
"""
 
    vcs.backends.git.repository
 
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
    Git repository implementation.
 

	
 
    :created_on: Apr 8, 2010
 
    :copyright: (c) 2010-2011 by Marcin Kuzminski, Lukasz Balcerzak.
 
"""
 

	
 
import os
 
import re
 
import time
 
import urllib
 
import urllib2
 
import logging
 
import posixpath
 
import string
 
try:
 
    # Python <=2.7
 
    from pipes import quote
 
except ImportError:
 
    # Python 3.3+
 
    from shlex import quote
 

	
 
from dulwich.objects import Tag
 
from dulwich.repo import Repo, NotGitRepository
 
from dulwich.config import ConfigFile
 

	
 
from kallithea.lib.vcs import subprocessio
 
from kallithea.lib.vcs.backends.base import BaseRepository, CollectionGenerator
 
from kallithea.lib.vcs.conf import settings
 

	
 
from kallithea.lib.vcs.exceptions import (
 
    BranchDoesNotExistError, ChangesetDoesNotExistError, EmptyRepositoryError,
 
    RepositoryError, TagAlreadyExistError, TagDoesNotExistError
 
)
 
from kallithea.lib.vcs.utils import safe_unicode, makedate, date_fromtimestamp
 
from kallithea.lib.vcs.utils.lazy import LazyProperty
 
from kallithea.lib.vcs.utils.ordered_dict import OrderedDict
 
from kallithea.lib.vcs.utils.paths import abspath, get_user_home
 

	
 
from kallithea.lib.vcs.utils.hgcompat import (
 
    hg_url, httpbasicauthhandler, httpdigestauthhandler
 
)
 

	
 
from .changeset import GitChangeset
 
from .inmemory import GitInMemoryChangeset
 
from .workdir import GitWorkdir
 

	
 
SHA_PATTERN = re.compile(r'^[[0-9a-fA-F]{12}|[0-9a-fA-F]{40}]$')
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class GitRepository(BaseRepository):
 
    """
 
    Git repository backend.
 
    """
 
    DEFAULT_BRANCH_NAME = 'master'
 
    scm = 'git'
 

	
 
    def __init__(self, repo_path, create=False, src_url=None,
 
                 update_after_clone=False, bare=False):
 

	
 
        self.path = abspath(repo_path)
 
        repo = self._get_repo(create, src_url, update_after_clone, bare)
 
        self.bare = repo.bare
 

	
 
    @property
 
    def _config_files(self):
 
        return [
 
            self.bare and abspath(self.path, 'config')
 
                      or abspath(self.path, '.git', 'config'),
 
             abspath(get_user_home(), '.gitconfig'),
 
         ]
 

	
 
    @property
 
    def _repo(self):
 
        return Repo(self.path)
 

	
 
    @property
 
    def head(self):
 
        try:
 
            return self._repo.head()
 
        except KeyError:
 
            return None
 

	
 
    @property
 
    def _empty(self):
 
        """
 
        Checks if repository is empty ie. without any changesets
 
        """
 

	
 
        try:
 
            self.revisions[0]
 
        except (KeyError, IndexError):
 
            return True
 
        return False
 

	
 
    @LazyProperty
 
    def revisions(self):
 
        """
 
        Returns list of revisions' ids, in ascending order.  Being lazy
 
        attribute allows external tools to inject shas from cache.
 
        """
 
        return self._get_all_revisions()
 

	
 
    @classmethod
 
    def _run_git_command(cls, cmd, **opts):
 
        """
 
        Runs given ``cmd`` as git command and returns tuple
 
        (stdout, stderr).
 

	
 
        :param cmd: git command to be executed
 
        :param opts: env options to pass into Subprocess command
 
        """
 

	
 
        if '_bare' in opts:
 
            _copts = []
 
            del opts['_bare']
 
        else:
 
            _copts = ['-c', 'core.quotepath=false', ]
 
        safe_call = False
 
        if '_safe' in opts:
 
            #no exc on failure
 
            del opts['_safe']
 
            safe_call = True
 

	
 
        _str_cmd = False
 
        if isinstance(cmd, basestring):
 
            cmd = [cmd]
 
            _str_cmd = True
 

	
 
        gitenv = os.environ
 
        # need to clean fix GIT_DIR !
 
        if 'GIT_DIR' in gitenv:
 
            del gitenv['GIT_DIR']
 
        gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
 

	
 
        _git_path = settings.GIT_EXECUTABLE_PATH
 
        cmd = [_git_path] + _copts + cmd
 
        if _str_cmd:
 
            cmd = ' '.join(cmd)
 

	
 
        try:
 
            _opts = dict(
 
                env=gitenv,
 
                shell=True,
 
            )
 
            _opts.update(opts)
 
            p = subprocessio.SubprocessIOChunker(cmd, **_opts)
 
        except (EnvironmentError, OSError), err:
 
            tb_err = ("Couldn't run git command (%s).\n"
 
                      "Original error was:%s\n" % (cmd, err))
 
            log.error(tb_err)
 
            if safe_call:
 
                return '', err
 
            else:
 
                raise RepositoryError(tb_err)
 

	
 
        return ''.join(p.output), ''.join(p.error)
 

	
 
    def run_git_command(self, cmd):
 
        opts = {}
 
        if os.path.isdir(self.path):
 
            opts['cwd'] = self.path
 
        return self._run_git_command(cmd, **opts)
 

	
 
    @classmethod
 
    def _check_url(cls, url):
 
        """
 
        Function will check given url and try to verify if it's a valid
 
        link. Sometimes it may happened that git will issue basic
 
        auth request that can cause whole API to hang when used from python
 
        or other external calls.
 

	
 
        On failures it'll raise urllib2.HTTPError, exception is also thrown
 
        when the return code is non 200
 
        """
 

	
 
        # check first if it's not an local url
 
        if os.path.isdir(url) or url.startswith('file:'):
 
            return True
 

	
 
        if '+' in url[:url.find('://')]:
 
            url = url[url.find('+') + 1:]
 

	
 
        handlers = []
 
        url_obj = hg_url(url)
 
        test_uri, authinfo = url_obj.authinfo()
 
        url_obj.passwd = '*****'
 
        cleaned_uri = str(url_obj)
 

	
 
        if not test_uri.endswith('info/refs'):
 
            test_uri = test_uri.rstrip('/') + '/info/refs'
 

	
 
        if authinfo:
 
            #create a password manager
 
            passmgr = urllib2.HTTPPasswordMgrWithDefaultRealm()
 
            passmgr.add_password(*authinfo)
 

	
 
            handlers.extend((httpbasicauthhandler(passmgr),
 
                             httpdigestauthhandler(passmgr)))
 

	
 
        o = urllib2.build_opener(*handlers)
 
        o.addheaders = [('User-Agent', 'git/1.7.8.0')]  # fake some git
 

	
 
        q = {"service": 'git-upload-pack'}
 
        qs = '?%s' % urllib.urlencode(q)
 
        cu = "%s%s" % (test_uri, qs)
 
        req = urllib2.Request(cu, None, {})
 

	
 
        try:
 
            resp = o.open(req)
 
            if resp.code != 200:
 
                raise Exception('Return Code is not 200')
 
        except Exception, e:
 
            # means it cannot be cloned
 
            raise urllib2.URLError("[%s] org_exc: %s" % (cleaned_uri, e))
 

	
 
        # now detect if it's proper git repo
 
        gitdata = resp.read()
 
        if not 'service=git-upload-pack' in gitdata:
 
            raise urllib2.URLError(
 
                "url [%s] does not look like an git" % (cleaned_uri))
 

	
 
        return True
 

	
 
    def _get_repo(self, create, src_url=None, update_after_clone=False,
 
                  bare=False):
 
        if create and os.path.exists(self.path):
 
            raise RepositoryError("Location already exist")
 
        if src_url and not create:
 
            raise RepositoryError("Create should be set to True if src_url is "
 
                                  "given (clone operation creates repository)")
 
        try:
 
            if create and src_url:
 
                GitRepository._check_url(src_url)
 
                self.clone(src_url, update_after_clone, bare)
 
                return Repo(self.path)
 
            elif create:
 
                os.makedirs(self.path)
 
                if bare:
 
                    return Repo.init_bare(self.path)
 
                else:
 
                    return Repo.init(self.path)
 
            else:
 
                return self._repo
 
        except (NotGitRepository, OSError), err:
 
            raise RepositoryError(err)
 

	
 
    def _get_all_revisions(self):
 
        # we must check if this repo is not empty, since later command
 
        # fails if it is. And it's cheaper to ask than throw the subprocess
 
        # errors
 
        try:
 
            self._repo.head()
 
        except KeyError:
 
            return []
 

	
 
        rev_filter = settings.GIT_REV_FILTER
 
        cmd = 'rev-list %s --reverse --date-order' % (rev_filter)
 
        try:
 
            so, se = self.run_git_command(cmd)
 
        except RepositoryError:
 
            # Can be raised for empty repositories
 
            return []
 
        return so.splitlines()
 

	
 
    def _get_all_revisions2(self):
 
        #alternate implementation using dulwich
 
        includes = [x[1][0] for x in self._parsed_refs.iteritems()
 
                    if x[1][1] != 'T']
 
        return [c.commit.id for c in self._repo.get_walker(include=includes)]
 

	
 
    def _get_revision(self, revision):
 
        """
 
        For git backend we always return integer here. This way we ensure
 
        that changset's revision attribute would become integer.
 
        """
 

	
 
        is_null = lambda o: len(o) == revision.count('0')
 

	
 
        if self._empty:
 
            raise EmptyRepositoryError("There are no changesets yet")
 

	
 
        if revision in (None, '', 'tip', 'HEAD', 'head', -1):
 
            return self.revisions[-1]
 

	
 
        is_bstr = isinstance(revision, (str, unicode))
 
        if ((is_bstr and revision.isdigit() and len(revision) < 12)
 
            or isinstance(revision, int) or is_null(revision)):
 
            try:
 
                revision = self.revisions[int(revision)]
 
            except Exception:
 
            except IndexError:
 
                msg = ("Revision %s does not exist for %s" % (revision, self))
 
                raise ChangesetDoesNotExistError(msg)
 

	
 
        elif is_bstr:
 
            # get by branch/tag name
 
            _ref_revision = self._parsed_refs.get(revision)
 
            if _ref_revision:  # and _ref_revision[1] in ['H', 'RH', 'T']:
 
                return _ref_revision[0]
 

	
 
            _tags_shas = self.tags.values()
 
            # maybe it's a tag ? we don't have them in self.revisions
 
            if revision in _tags_shas:
 
                return _tags_shas[_tags_shas.index(revision)]
 

	
 
            elif not SHA_PATTERN.match(revision) or revision not in self.revisions:
 
                msg = ("Revision %s does not exist for %s" % (revision, self))
 
                raise ChangesetDoesNotExistError(msg)
 

	
 
        # Ensure we return full id
 
        if not SHA_PATTERN.match(str(revision)):
 
            raise ChangesetDoesNotExistError("Given revision %s not recognized"
 
                % revision)
 
        return revision
 

	
 
    def get_ref_revision(self, ref_type, ref_name):
 
        """
 
        Returns ``MercurialChangeset`` object representing repository's
 
        changeset at the given ``revision``.
 
        """
 
        return self._get_revision(ref_name)
 

	
 
    def _get_archives(self, archive_name='tip'):
 

	
 
        for i in [('zip', '.zip'), ('gz', '.tar.gz'), ('bz2', '.tar.bz2')]:
 
                yield {"type": i[0], "extension": i[1], "node": archive_name}
 

	
 
    def _get_url(self, url):
 
        """
 
        Returns normalized url. If schema is not given, would fall to
 
        filesystem (``file:///``) schema.
 
        """
 
        url = str(url)
 
        if url != 'default' and not '://' in url:
 
            url = ':///'.join(('file', url))
 
        return url
 

	
 
    def get_hook_location(self):
 
        """
 
        returns absolute path to location where hooks are stored
 
        """
 
        loc = os.path.join(self.path, 'hooks')
 
        if not self.bare:
 
            loc = os.path.join(self.path, '.git', 'hooks')
 
        return loc
 

	
 
    @LazyProperty
 
    def name(self):
 
        return os.path.basename(self.path)
 

	
 
    @LazyProperty
 
    def last_change(self):
 
        """
 
        Returns last change made on this repository as datetime object
 
        """
 
        return date_fromtimestamp(self._get_mtime(), makedate()[1])
 

	
 
    def _get_mtime(self):
 
        try:
 
            return time.mktime(self.get_changeset().date.timetuple())
 
        except RepositoryError:
 
            idx_loc = '' if self.bare else '.git'
 
            # fallback to filesystem
 
            in_path = os.path.join(self.path, idx_loc, "index")
 
            he_path = os.path.join(self.path, idx_loc, "HEAD")
 
            if os.path.exists(in_path):
 
                return os.stat(in_path).st_mtime
 
            else:
 
                return os.stat(he_path).st_mtime
 

	
 
    @LazyProperty
 
    def description(self):
 
        undefined_description = u'unknown'
 
        _desc = self._repo.get_description()
 
        return safe_unicode(_desc or undefined_description)
 

	
 
    @LazyProperty
 
    def contact(self):
 
        undefined_contact = u'Unknown'
 
        return undefined_contact
 

	
 
    @property
 
    def branches(self):
 
        if not self.revisions:
 
            return {}
 
        sortkey = lambda ctx: ctx[0]
 
        _branches = [(x[0], x[1][0])
 
                     for x in self._parsed_refs.iteritems() if x[1][1] == 'H']
 
        return OrderedDict(sorted(_branches, key=sortkey, reverse=False))
 

	
 
    @LazyProperty
 
    def closed_branches(self):
 
        return {}
 

	
 
    @LazyProperty
 
    def tags(self):
 
        return self._get_tags()
 

	
 
    def _get_tags(self):
 
        if not self.revisions:
 
            return {}
 

	
 
        sortkey = lambda ctx: ctx[0]
 
        _tags = [(x[0], x[1][0])
 
                 for x in self._parsed_refs.iteritems() if x[1][1] == 'T']
 
        return OrderedDict(sorted(_tags, key=sortkey, reverse=True))
 

	
 
    def tag(self, name, user, revision=None, message=None, date=None,
 
            **kwargs):
 
        """
 
        Creates and returns a tag for the given ``revision``.
 

	
 
        :param name: name for new tag
 
        :param user: full username, i.e.: "Joe Doe <joe.doe@example.com>"
 
        :param revision: changeset id for which new tag would be created
 
        :param message: message of the tag's commit
 
        :param date: date of tag's commit
 

	
 
        :raises TagAlreadyExistError: if tag with same name already exists
 
        """
 
        if name in self.tags:
 
            raise TagAlreadyExistError("Tag %s already exists" % name)
 
        changeset = self.get_changeset(revision)
 
        message = message or "Added tag %s for commit %s" % (name,
 
            changeset.raw_id)
 
        self._repo.refs["refs/tags/%s" % name] = changeset._commit.id
 

	
 
        self._parsed_refs = self._get_parsed_refs()
 
        self.tags = self._get_tags()
 
        return changeset
 

	
 
    def remove_tag(self, name, user, message=None, date=None):
 
        """
 
        Removes tag with the given ``name``.
 

	
 
        :param name: name of the tag to be removed
 
        :param user: full username, i.e.: "Joe Doe <joe.doe@example.com>"
 
        :param message: message of the tag's removal commit
 
        :param date: date of tag's removal commit
 

	
 
        :raises TagDoesNotExistError: if tag with given name does not exists
 
        """
 
        if name not in self.tags:
 
            raise TagDoesNotExistError("Tag %s does not exist" % name)
 
        tagpath = posixpath.join(self._repo.refs.path, 'refs', 'tags', name)
 
        try:
 
            os.remove(tagpath)
 
            self._parsed_refs = self._get_parsed_refs()
 
            self.tags = self._get_tags()
 
        except OSError, e:
 
            raise RepositoryError(e.strerror)
 

	
 
    @LazyProperty
 
    def bookmarks(self):
 
        """
 
        Gets bookmarks for this repository
 
        """
 
        return {}
 

	
 
    @LazyProperty
 
    def _parsed_refs(self):
 
        return self._get_parsed_refs()
 

	
 
    def _get_parsed_refs(self):
 
        # cache the property
 
        _repo = self._repo
 
        refs = _repo.get_refs()
 
        keys = [('refs/heads/', 'H'),
 
                ('refs/remotes/origin/', 'RH'),
 
                ('refs/tags/', 'T')]
 
        _refs = {}
 
        for ref, sha in refs.iteritems():
 
            for k, type_ in keys:
 
                if ref.startswith(k):
 
                    _key = ref[len(k):]
 
                    if type_ == 'T':
 
                        obj = _repo.get_object(sha)
 
                        if isinstance(obj, Tag):
 
                            sha = _repo.get_object(sha).object[1]
 
                    _refs[_key] = [sha, type_]
 
                    break
 
        return _refs
 

	
 
    def _heads(self, reverse=False):
 
        refs = self._repo.get_refs()
 
        heads = {}
 

	
 
        for key, val in refs.items():
 
            for ref_key in ['refs/heads/', 'refs/remotes/origin/']:
 
                if key.startswith(ref_key):
 
                    n = key[len(ref_key):]
 
                    if n not in ['HEAD']:
 
                        heads[n] = val
 

	
 
        return heads if reverse else dict((y, x) for x, y in heads.iteritems())
 

	
 
    def get_changeset(self, revision=None):
 
        """
 
        Returns ``GitChangeset`` object representing commit from git repository
 
        at the given revision or head (most recent commit) if None given.
 
        """
 
        if isinstance(revision, GitChangeset):
 
            return revision
 
        revision = self._get_revision(revision)
 
        changeset = GitChangeset(repository=self, revision=revision)
 
        return changeset
 

	
 
    def get_changesets(self, start=None, end=None, start_date=None,
 
           end_date=None, branch_name=None, reverse=False):
 
        """
 
        Returns iterator of ``GitChangeset`` objects from start to end (both
 
        are inclusive), in ascending date order (unless ``reverse`` is set).
 

	
 
        :param start: changeset ID, as str; first returned changeset
 
        :param end: changeset ID, as str; last returned changeset
 
        :param start_date: if specified, changesets with commit date less than
 
          ``start_date`` would be filtered out from returned set
 
        :param end_date: if specified, changesets with commit date greater than
 
          ``end_date`` would be filtered out from returned set
 
        :param branch_name: if specified, changesets not reachable from given
 
          branch would be filtered out from returned set
 
        :param reverse: if ``True``, returned generator would be reversed
 
          (meaning that returned changesets would have descending date order)
 

	
 
        :raise BranchDoesNotExistError: If given ``branch_name`` does not
 
            exist.
 
        :raise ChangesetDoesNotExistError: If changeset for given ``start`` or
 
          ``end`` could not be found.
 

	
 
        """
 
        if branch_name and branch_name not in self.branches:
 
            raise BranchDoesNotExistError("Branch '%s' not found" \
 
                                          % branch_name)
 
        # actually we should check now if it's not an empty repo to not spaw
 
        # subprocess commands
 
        if self._empty:
 
            raise EmptyRepositoryError("There are no changesets yet")
 

	
 
        # %H at format means (full) commit hash, initial hashes are retrieved
 
        # in ascending date order
 
        cmd_template = 'log --date-order --reverse --pretty=format:"%H"'
 
        cmd_params = {}
 
        if start_date:
 
            cmd_template += ' --since "$since"'
 
            cmd_params['since'] = start_date.strftime('%m/%d/%y %H:%M:%S')
 
        if end_date:
 
            cmd_template += ' --until "$until"'
 
            cmd_params['until'] = end_date.strftime('%m/%d/%y %H:%M:%S')
 
        if branch_name:
 
            cmd_template += ' $branch_name'
 
            cmd_params['branch_name'] = branch_name
 
        else:
 
            rev_filter = settings.GIT_REV_FILTER
 
            cmd_template += ' %s' % (rev_filter)
 

	
 
        cmd = string.Template(cmd_template).safe_substitute(**cmd_params)
 
        revs = self.run_git_command(cmd)[0].splitlines()
 
        start_pos = 0
 
        end_pos = len(revs)
 
        if start:
 
            _start = self._get_revision(start)
 
            try:
 
                start_pos = revs.index(_start)
 
            except ValueError:
 
                pass
 

	
 
        if end is not None:
 
            _end = self._get_revision(end)
 
            try:
 
                end_pos = revs.index(_end)
 
            except ValueError:
 
                pass
 

	
 
        if None not in [start, end] and start_pos > end_pos:
 
            raise RepositoryError('start cannot be after end')
 

	
 
        if end_pos is not None:
 
            end_pos += 1
 

	
 
        revs = revs[start_pos:end_pos]
 
        if reverse:
 
            revs = reversed(revs)
 
        return CollectionGenerator(self, revs)
 

	
 
    def get_diff(self, rev1, rev2, path=None, ignore_whitespace=False,
 
                 context=3):
 
        """
 
        Returns (git like) *diff*, as plain text. Shows changes introduced by
 
        ``rev2`` since ``rev1``.
 

	
 
        :param rev1: Entry point from which diff is shown. Can be
 
          ``self.EMPTY_CHANGESET`` - in this case, patch showing all
 
          the changes since empty state of the repository until ``rev2``
 
        :param rev2: Until which revision changes should be shown.
 
        :param ignore_whitespace: If set to ``True``, would not show whitespace
 
          changes. Defaults to ``False``.
 
        :param context: How many lines before/after changed lines should be
 
          shown. Defaults to ``3``.
 
        """
 
        flags = ['-U%s' % context, '--full-index', '--binary', '-p', '-M', '--abbrev=40']
 
        if ignore_whitespace:
 
            flags.append('-w')
 

	
 
        if hasattr(rev1, 'raw_id'):
 
            rev1 = getattr(rev1, 'raw_id')
 

	
 
        if hasattr(rev2, 'raw_id'):
 
            rev2 = getattr(rev2, 'raw_id')
 

	
 
        if rev1 == self.EMPTY_CHANGESET:
 
            rev2 = self.get_changeset(rev2).raw_id
 
            cmd = ' '.join(['show'] + flags + [rev2])
 
        else:
 
            rev1 = self.get_changeset(rev1).raw_id
 
            rev2 = self.get_changeset(rev2).raw_id
 
            cmd = ' '.join(['diff'] + flags + [rev1, rev2])
 

	
 
        if path:
 
            cmd += ' -- "%s"' % path
 

	
 
        stdout, stderr = self.run_git_command(cmd)
 
        # If we used 'show' command, strip first few lines (until actual diff
 
        # starts)
 
        if rev1 == self.EMPTY_CHANGESET:
 
            lines = stdout.splitlines()
 
            x = 0
 
            for line in lines:
 
                if line.startswith('diff'):
 
                    break
 
                x += 1
 
            # Append new line just like 'diff' command do
 
            stdout = '\n'.join(lines[x:]) + '\n'
 
        return stdout
 

	
 
    @LazyProperty
 
    def in_memory_changeset(self):
 
        """
 
        Returns ``GitInMemoryChangeset`` object for this repository.
 
        """
 
        return GitInMemoryChangeset(self)
 

	
 
    def clone(self, url, update_after_clone=True, bare=False):
 
        """
 
        Tries to clone changes from external location.
 

	
 
        :param update_after_clone: If set to ``False``, git won't checkout
 
          working directory
 
        :param bare: If set to ``True``, repository would be cloned into
 
          *bare* git repository (no working directory at all).
 
        """
 
        url = self._get_url(url)
 
        cmd = ['clone', '-q']
 
        if bare:
 
            cmd.append('--bare')
 
        elif not update_after_clone:
 
            cmd.append('--no-checkout')
 
        cmd += ['--', quote(url), quote(self.path)]
 
        cmd = ' '.join(cmd)
 
        # If error occurs run_git_command raises RepositoryError already
 
        self.run_git_command(cmd)
 

	
 
    def pull(self, url):
 
        """
 
        Tries to pull changes from external location.
 
        """
 
        url = self._get_url(url)
 
        cmd = ['pull', "--ff-only", quote(url)]
 
        cmd = ' '.join(cmd)
 
        # If error occurs run_git_command raises RepositoryError already
 
        self.run_git_command(cmd)
 

	
 
    def fetch(self, url):
 
        """
 
        Tries to pull changes from external location.
 
        """
 
        url = self._get_url(url)
 
        so, se = self.run_git_command('ls-remote -h %s' % quote(url))
 
        refs = []
 
        for line in (x for x in so.splitlines()):
 
            sha, ref = line.split('\t')
 
            refs.append(ref)
 
        refs = ' '.join(('+%s:%s' % (r, r) for r in refs))
 
        cmd = '''fetch %s -- %s''' % (quote(url), refs)
 
        self.run_git_command(cmd)
 

	
 
    def _update_server_info(self):
 
        """
 
        runs gits update-server-info command in this repo instance
 
        """
 
        from dulwich.server import update_server_info
 
        update_server_info(self._repo)
 

	
 
    @LazyProperty
 
    def workdir(self):
 
        """
 
        Returns ``Workdir`` instance for this repository.
 
        """
 
        return GitWorkdir(self)
 

	
 
    def get_config_value(self, section, name, config_file=None):
 
        """
 
        Returns configuration value for a given [``section``] and ``name``.
 

	
 
        :param section: Section we want to retrieve value from
 
        :param name: Name of configuration we want to retrieve
 
        :param config_file: A path to file which should be used to retrieve
 
          configuration from (might also be a list of file paths)
 
        """
 
        if config_file is None:
 
            config_file = []
 
        elif isinstance(config_file, basestring):
 
            config_file = [config_file]
 

	
 
        def gen_configs():
 
            for path in config_file + self._config_files:
 
                try:
 
                    yield ConfigFile.from_path(path)
 
                except (IOError, OSError, ValueError):
 
                    continue
 

	
 
        for config in gen_configs():
 
            try:
 
                return config.get(section, name)
 
            except KeyError:
 
                continue
 
        return None
 

	
 
    def get_user_name(self, config_file=None):
 
        """
 
        Returns user's name from global configuration file.
 

	
 
        :param config_file: A path to file which should be used to retrieve
 
          configuration from (might also be a list of file paths)
 
        """
 
        return self.get_config_value('user', 'name', config_file)
 

	
 
    def get_user_email(self, config_file=None):
 
        """
 
        Returns user's email from global configuration file.
 

	
 
        :param config_file: A path to file which should be used to retrieve
 
          configuration from (might also be a list of file paths)
 
        """
 
        return self.get_config_value('user', 'email', config_file)
kallithea/lib/vcs/backends/hg/repository.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
"""
 
    vcs.backends.hg.repository
 
    ~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
    Mercurial repository implementation.
 

	
 
    :created_on: Apr 8, 2010
 
    :copyright: (c) 2010-2011 by Marcin Kuzminski, Lukasz Balcerzak.
 
"""
 

	
 
import os
 
import time
 
import urllib
 
import urllib2
 
import logging
 
import datetime
 

	
 
from kallithea.lib.vcs.backends.base import BaseRepository, CollectionGenerator
 

	
 
from kallithea.lib.vcs.exceptions import (
 
    BranchDoesNotExistError, ChangesetDoesNotExistError, EmptyRepositoryError,
 
    RepositoryError, VCSError, TagAlreadyExistError, TagDoesNotExistError
 
)
 
from kallithea.lib.vcs.utils import (
 
    author_email, author_name, date_fromtimestamp, makedate, safe_unicode, safe_str,
 
)
 
from kallithea.lib.vcs.utils.lazy import LazyProperty
 
from kallithea.lib.vcs.utils.ordered_dict import OrderedDict
 
from kallithea.lib.vcs.utils.paths import abspath
 
from kallithea.lib.vcs.utils.hgcompat import (
 
    ui, nullid, match, patch, diffopts, clone, get_contact, pull,
 
    localrepository, RepoLookupError, Abort, RepoError, hex, scmutil, hg_url,
 
    httpbasicauthhandler, httpdigestauthhandler, peer, httppeer
 
)
 

	
 
from .changeset import MercurialChangeset
 
from .inmemory import MercurialInMemoryChangeset
 
from .workdir import MercurialWorkdir
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class MercurialRepository(BaseRepository):
 
    """
 
    Mercurial repository backend
 
    """
 
    DEFAULT_BRANCH_NAME = 'default'
 
    scm = 'hg'
 

	
 
    def __init__(self, repo_path, create=False, baseui=None, src_url=None,
 
                 update_after_clone=False):
 
        """
 
        Raises RepositoryError if repository could not be find at the given
 
        ``repo_path``.
 

	
 
        :param repo_path: local path of the repository
 
        :param create=False: if set to True, would try to create repository if
 
           it does not exist rather than raising exception
 
        :param baseui=None: user data
 
        :param src_url=None: would try to clone repository from given location
 
        :param update_after_clone=False: sets update of working copy after
 
          making a clone
 
        """
 

	
 
        if not isinstance(repo_path, str):
 
            raise VCSError('Mercurial backend requires repository path to '
 
                           'be instance of <str> got %s instead' %
 
                           type(repo_path))
 

	
 
        self.path = abspath(repo_path)
 
        self.baseui = baseui or ui.ui()
 
        # We've set path and ui, now we can set _repo itself
 
        self._repo = self._get_repo(create, src_url, update_after_clone)
 

	
 
    @property
 
    def _empty(self):
 
        """
 
        Checks if repository is empty ie. without any changesets
 
        """
 
        # TODO: Following raises errors when using InMemoryChangeset...
 
        # return len(self._repo.changelog) == 0
 
        return len(self.revisions) == 0
 

	
 
    @LazyProperty
 
    def revisions(self):
 
        """
 
        Returns list of revisions' ids, in ascending order.  Being lazy
 
        attribute allows external tools to inject shas from cache.
 
        """
 
        return self._get_all_revisions()
 

	
 
    @LazyProperty
 
    def name(self):
 
        return os.path.basename(self.path)
 

	
 
    @LazyProperty
 
    def branches(self):
 
        return self._get_branches()
 

	
 
    @LazyProperty
 
    def closed_branches(self):
 
        return self._get_branches(normal=False, closed=True)
 

	
 
    @LazyProperty
 
    def allbranches(self):
 
        """
 
        List all branches, including closed branches.
 
        """
 
        return self._get_branches(closed=True)
 

	
 
    def _get_branches(self, normal=True, closed=False):
 
        """
 
        Gets branches for this repository
 
        Returns only not closed branches by default
 

	
 
        :param closed: return also closed branches for mercurial
 
        :param normal: return also normal branches
 
        """
 

	
 
        if self._empty:
 
            return {}
 

	
 
        bt = OrderedDict()
 
        for bn, _heads, tip, isclosed in sorted(self._repo.branchmap().iterbranches()):
 
            if isclosed:
 
                if closed:
 
                    bt[safe_unicode(bn)] = hex(tip)
 
            else:
 
                if normal:
 
                    bt[safe_unicode(bn)] = hex(tip)
 

	
 
        return bt
 

	
 
    @LazyProperty
 
    def tags(self):
 
        """
 
        Gets tags for this repository
 
        """
 
        return self._get_tags()
 

	
 
    def _get_tags(self):
 
        if self._empty:
 
            return {}
 

	
 
        sortkey = lambda ctx: ctx[0]  # sort by name
 
        _tags = [(safe_unicode(n), hex(h),) for n, h in
 
                 self._repo.tags().items()]
 

	
 
        return OrderedDict(sorted(_tags, key=sortkey, reverse=True))
 

	
 
    def tag(self, name, user, revision=None, message=None, date=None,
 
            **kwargs):
 
        """
 
        Creates and returns a tag for the given ``revision``.
 

	
 
        :param name: name for new tag
 
        :param user: full username, i.e.: "Joe Doe <joe.doe@example.com>"
 
        :param revision: changeset id for which new tag would be created
 
        :param message: message of the tag's commit
 
        :param date: date of tag's commit
 

	
 
        :raises TagAlreadyExistError: if tag with same name already exists
 
        """
 
        if name in self.tags:
 
            raise TagAlreadyExistError("Tag %s already exists" % name)
 
        changeset = self.get_changeset(revision)
 
        local = kwargs.setdefault('local', False)
 

	
 
        if message is None:
 
            message = "Added tag %s for changeset %s" % (name,
 
                changeset.short_id)
 

	
 
        if date is None:
 
            date = datetime.datetime.now().ctime()
 

	
 
        try:
 
            self._repo.tag(name, changeset._ctx.node(), message, local, user,
 
                date)
 
        except Abort, e:
 
            raise RepositoryError(e.message)
 

	
 
        # Reinitialize tags
 
        self.tags = self._get_tags()
 
        tag_id = self.tags[name]
 

	
 
        return self.get_changeset(revision=tag_id)
 

	
 
    def remove_tag(self, name, user, message=None, date=None):
 
        """
 
        Removes tag with the given ``name``.
 

	
 
        :param name: name of the tag to be removed
 
        :param user: full username, i.e.: "Joe Doe <joe.doe@example.com>"
 
        :param message: message of the tag's removal commit
 
        :param date: date of tag's removal commit
 

	
 
        :raises TagDoesNotExistError: if tag with given name does not exists
 
        """
 
        if name not in self.tags:
 
            raise TagDoesNotExistError("Tag %s does not exist" % name)
 
        if message is None:
 
            message = "Removed tag %s" % name
 
        if date is None:
 
            date = datetime.datetime.now().ctime()
 
        local = False
 

	
 
        try:
 
            self._repo.tag(name, nullid, message, local, user, date)
 
            self.tags = self._get_tags()
 
        except Abort, e:
 
            raise RepositoryError(e.message)
 

	
 
    @LazyProperty
 
    def bookmarks(self):
 
        """
 
        Gets bookmarks for this repository
 
        """
 
        return self._get_bookmarks()
 

	
 
    def _get_bookmarks(self):
 
        if self._empty:
 
            return {}
 

	
 
        sortkey = lambda ctx: ctx[0]  # sort by name
 
        _bookmarks = [(safe_unicode(n), hex(h),) for n, h in
 
                 self._repo._bookmarks.items()]
 
        return OrderedDict(sorted(_bookmarks, key=sortkey, reverse=True))
 

	
 
    def _get_all_revisions(self):
 

	
 
        return map(lambda x: hex(x[7]), self._repo.changelog.index)[:-1]
 

	
 
    def get_diff(self, rev1, rev2, path='', ignore_whitespace=False,
 
                  context=3):
 
        """
 
        Returns (git like) *diff*, as plain text. Shows changes introduced by
 
        ``rev2`` since ``rev1``.
 

	
 
        :param rev1: Entry point from which diff is shown. Can be
 
          ``self.EMPTY_CHANGESET`` - in this case, patch showing all
 
          the changes since empty state of the repository until ``rev2``
 
        :param rev2: Until which revision changes should be shown.
 
        :param ignore_whitespace: If set to ``True``, would not show whitespace
 
          changes. Defaults to ``False``.
 
        :param context: How many lines before/after changed lines should be
 
          shown. Defaults to ``3``.
 
        """
 
        if hasattr(rev1, 'raw_id'):
 
            rev1 = getattr(rev1, 'raw_id')
 

	
 
        if hasattr(rev2, 'raw_id'):
 
            rev2 = getattr(rev2, 'raw_id')
 

	
 
        # Check if given revisions are present at repository (may raise
 
        # ChangesetDoesNotExistError)
 
        if rev1 != self.EMPTY_CHANGESET:
 
            self.get_changeset(rev1)
 
        self.get_changeset(rev2)
 
        if path:
 
            file_filter = match(self.path, '', [path])
 
        else:
 
            file_filter = None
 

	
 
        return ''.join(patch.diff(self._repo, rev1, rev2, match=file_filter,
 
                          opts=diffopts(git=True,
 
                                        ignorews=ignore_whitespace,
 
                                        context=context)))
 

	
 
    @classmethod
 
    def _check_url(cls, url, repoui=None):
 
        """
 
        Function will check given url and try to verify if it's a valid
 
        link. Sometimes it may happened that mercurial will issue basic
 
        auth request that can cause whole API to hang when used from python
 
        or other external calls.
 

	
 
        On failures it'll raise urllib2.HTTPError, exception is also thrown
 
        when the return code is non 200
 
        """
 
        # check first if it's not an local url
 
        if os.path.isdir(url) or url.startswith('file:'):
 
            return True
 

	
 
        if '+' in url[:url.find('://')]:
 
            url = url[url.find('+') + 1:]
 

	
 
        handlers = []
 
        url_obj = hg_url(url)
 
        test_uri, authinfo = url_obj.authinfo()
 
        url_obj.passwd = '*****'
 
        cleaned_uri = str(url_obj)
 

	
 
        if authinfo:
 
            #create a password manager
 
            passmgr = urllib2.HTTPPasswordMgrWithDefaultRealm()
 
            passmgr.add_password(*authinfo)
 

	
 
            handlers.extend((httpbasicauthhandler(passmgr),
 
                             httpdigestauthhandler(passmgr)))
 

	
 
        o = urllib2.build_opener(*handlers)
 
        o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
 
                        ('Accept', 'application/mercurial-0.1')]
 

	
 
        q = {"cmd": 'between'}
 
        q.update({'pairs': "%s-%s" % ('0' * 40, '0' * 40)})
 
        qs = '?%s' % urllib.urlencode(q)
 
        cu = "%s%s" % (test_uri, qs)
 
        req = urllib2.Request(cu, None, {})
 

	
 
        try:
 
            resp = o.open(req)
 
            if resp.code != 200:
 
                raise Exception('Return Code is not 200')
 
        except Exception, e:
 
            # means it cannot be cloned
 
            raise urllib2.URLError("[%s] org_exc: %s" % (cleaned_uri, e))
 

	
 
        # now check if it's a proper hg repo
 
        try:
 
            httppeer(repoui or ui.ui(), url).lookup('tip')
 
        except Exception, e:
 
            raise urllib2.URLError(
 
                "url [%s] does not look like an hg repo org_exc: %s"
 
                % (cleaned_uri, e))
 

	
 
        return True
 

	
 
    def _get_repo(self, create, src_url=None, update_after_clone=False):
 
        """
 
        Function will check for mercurial repository in given path and return
 
        a localrepo object. If there is no repository in that path it will
 
        raise an exception unless ``create`` parameter is set to True - in
 
        that case repository would be created and returned.
 
        If ``src_url`` is given, would try to clone repository from the
 
        location at given clone_point. Additionally it'll make update to
 
        working copy accordingly to ``update_after_clone`` flag
 
        """
 

	
 
        try:
 
            if src_url:
 
                url = str(self._get_url(src_url))
 
                opts = {}
 
                if not update_after_clone:
 
                    opts.update({'noupdate': True})
 
                try:
 
                    MercurialRepository._check_url(url, self.baseui)
 
                    clone(self.baseui, url, self.path, **opts)
 
#                except urllib2.URLError:
 
#                    raise Abort("Got HTTP 404 error")
 
                except Exception:
 
                    raise
 
                MercurialRepository._check_url(url, self.baseui)
 
                clone(self.baseui, url, self.path, **opts)
 

	
 
                # Don't try to create if we've already cloned repo
 
                create = False
 
            return localrepository(self.baseui, self.path, create=create)
 
        except (Abort, RepoError), err:
 
            if create:
 
                msg = "Cannot create repository at %s. Original error was %s"\
 
                    % (self.path, err)
 
            else:
 
                msg = "Not valid repository at %s. Original error was %s"\
 
                    % (self.path, err)
 
            raise RepositoryError(msg)
 

	
 
    @LazyProperty
 
    def in_memory_changeset(self):
 
        return MercurialInMemoryChangeset(self)
 

	
 
    @LazyProperty
 
    def description(self):
 
        undefined_description = u'unknown'
 
        _desc = self._repo.ui.config('web', 'description', None, untrusted=True)
 
        return safe_unicode(_desc or undefined_description)
 

	
 
    @LazyProperty
 
    def contact(self):
 
        undefined_contact = u'Unknown'
 
        return safe_unicode(get_contact(self._repo.ui.config)
 
                            or undefined_contact)
 

	
 
    @LazyProperty
 
    def last_change(self):
 
        """
 
        Returns last change made on this repository as datetime object
 
        """
 
        return date_fromtimestamp(self._get_mtime(), makedate()[1])
 

	
 
    def _get_mtime(self):
 
        try:
 
            return time.mktime(self.get_changeset().date.timetuple())
 
        except RepositoryError:
 
            #fallback to filesystem
 
            cl_path = os.path.join(self.path, '.hg', "00changelog.i")
 
            st_path = os.path.join(self.path, '.hg', "store")
 
            if os.path.exists(cl_path):
 
                return os.stat(cl_path).st_mtime
 
            else:
 
                return os.stat(st_path).st_mtime
 

	
 
    def _get_revision(self, revision):
 
        """
 
        Gets an ID revision given as str. This will always return a fill
 
        40 char revision number
 

	
 
        :param revision: str or int or None
 
        """
 
        if isinstance(revision, unicode):
 
            revision = safe_str(revision)
 

	
 
        if self._empty:
 
            raise EmptyRepositoryError("There are no changesets yet")
 

	
 
        if revision in [-1, 'tip', None]:
 
            revision = 'tip'
 

	
 
        try:
 
            revision = hex(self._repo.lookup(revision))
 
        except (LookupError, ):
 
            msg = ("Ambiguous identifier `%s` for %s" % (revision, self))
 
            raise ChangesetDoesNotExistError(msg)
 
        except (IndexError, ValueError, RepoLookupError, TypeError):
 
            msg = ("Revision %s does not exist for %s" % (revision, self))
 
            raise ChangesetDoesNotExistError(msg)
 

	
 
        return revision
 

	
 
    def get_ref_revision(self, ref_type, ref_name):
 
        """
 
        Returns revision number for the given reference.
 
        """
 
        ref_name = safe_str(ref_name)
 
        if ref_type == 'rev' and not ref_name.strip('0'):
 
            return self.EMPTY_CHANGESET
 
        # lookup up the exact node id
 
        _revset_predicates = {
 
                'branch': 'branch',
 
                'book': 'bookmark',
 
                'tag': 'tag',
 
                'rev': 'id',
 
            }
 
        # avoid expensive branch(x) iteration over whole repo
 
        rev_spec = "%%s & %s(%%s)" % _revset_predicates[ref_type]
 
        try:
 
            revs = self._repo.revs(rev_spec, ref_name, ref_name)
 
        except LookupError:
 
            msg = ("Ambiguous identifier %s:%s for %s" % (ref_type, ref_name, self.name))
 
            raise ChangesetDoesNotExistError(msg)
 
        except RepoLookupError:
 
            msg = ("Revision %s:%s does not exist for %s" % (ref_type, ref_name, self.name))
 
            raise ChangesetDoesNotExistError(msg)
 
        if revs:
 
            revision = revs[-1]
 
        else:
 
            # TODO: just report 'not found'?
 
            revision = ref_name
 

	
 
        return self._get_revision(revision)
 

	
 
    def _get_archives(self, archive_name='tip'):
 
        allowed = self.baseui.configlist("web", "allow_archive",
 
                                         untrusted=True)
 
        for i in [('zip', '.zip'), ('gz', '.tar.gz'), ('bz2', '.tar.bz2')]:
 
            if i[0] in allowed or self._repo.ui.configbool("web",
 
                                                           "allow" + i[0],
 
                                                           untrusted=True):
 
                yield {"type": i[0], "extension": i[1], "node": archive_name}
 

	
 
    def _get_url(self, url):
 
        """
 
        Returns normalized url. If schema is not given, would fall
 
        to filesystem
 
        (``file:///``) schema.
 
        """
 
        url = str(url)
 
        if url != 'default' and not '://' in url:
 
            url = "file:" + urllib.pathname2url(url)
 
        return url
 

	
 
    def get_hook_location(self):
 
        """
 
        returns absolute path to location where hooks are stored
 
        """
 
        return os.path.join(self.path, '.hg', '.hgrc')
 

	
 
    def get_changeset(self, revision=None):
 
        """
 
        Returns ``MercurialChangeset`` object representing repository's
 
        changeset at the given ``revision``.
 
        """
 
        revision = self._get_revision(revision)
 
        changeset = MercurialChangeset(repository=self, revision=revision)
 
        return changeset
 

	
 
    def get_changesets(self, start=None, end=None, start_date=None,
 
                       end_date=None, branch_name=None, reverse=False):
 
        """
 
        Returns iterator of ``MercurialChangeset`` objects from start to end
 
        (both are inclusive)
 

	
 
        :param start: None, str, int or mercurial lookup format
 
        :param end:  None, str, int or mercurial lookup format
 
        :param start_date:
 
        :param end_date:
 
        :param branch_name:
 
        :param reversed: return changesets in reversed order
 
        """
 

	
 
        start_raw_id = self._get_revision(start)
 
        start_pos = self.revisions.index(start_raw_id) if start else None
 
        end_raw_id = self._get_revision(end)
 
        end_pos = self.revisions.index(end_raw_id) if end else None
 

	
 
        if None not in [start, end] and start_pos > end_pos:
 
            raise RepositoryError("Start revision '%s' cannot be "
 
                                  "after end revision '%s'" % (start, end))
 

	
 
        if branch_name and branch_name not in self.allbranches.keys():
 
            msg = ("Branch %s not found in %s" % (branch_name, self))
 
            raise BranchDoesNotExistError(msg)
 
        if end_pos is not None:
 
            end_pos += 1
 
        #filter branches
 
        filter_ = []
 
        if branch_name:
 
            filter_.append('branch("%s")' % (branch_name))
 

	
 
        if start_date and not end_date:
 
            filter_.append('date(">%s")' % start_date)
 
        if end_date and not start_date:
 
            filter_.append('date("<%s")' % end_date)
 
        if start_date and end_date:
 
            filter_.append('date(">%s") and date("<%s")' % (start_date, end_date))
 
        if filter_:
 
            revisions = scmutil.revrange(self._repo, filter_)
 
        else:
 
            revisions = self.revisions
 

	
 
        revs = revisions[start_pos:end_pos]
 
        if reverse:
 
            revs = reversed(revs)
 

	
 
        return CollectionGenerator(self, revs)
 

	
 
    def pull(self, url):
 
        """
 
        Tries to pull changes from external location.
 
        """
 
        url = self._get_url(url)
 
        try:
 
            other = peer(self._repo, {}, url)
 
            self._repo.pull(other, heads=None, force=None)
 
        except Abort, err:
 
            # Propagate error but with vcs's type
 
            raise RepositoryError(str(err))
 

	
 
    @LazyProperty
 
    def workdir(self):
 
        """
 
        Returns ``Workdir`` instance for this repository.
 
        """
 
        return MercurialWorkdir(self)
 

	
 
    def get_config_value(self, section, name=None, config_file=None):
 
        """
 
        Returns configuration value for a given [``section``] and ``name``.
 

	
 
        :param section: Section we want to retrieve value from
 
        :param name: Name of configuration we want to retrieve
 
        :param config_file: A path to file which should be used to retrieve
 
          configuration from (might also be a list of file paths)
 
        """
 
        if config_file is None:
 
            config_file = []
 
        elif isinstance(config_file, basestring):
 
            config_file = [config_file]
 

	
 
        config = self._repo.ui
 
        for path in config_file:
 
            config.readconfig(path)
 
        return config.config(section, name)
 

	
 
    def get_user_name(self, config_file=None):
 
        """
 
        Returns user's name from global configuration file.
 

	
 
        :param config_file: A path to file which should be used to retrieve
 
          configuration from (might also be a list of file paths)
 
        """
 
        username = self.get_config_value('ui', 'username')
 
        if username:
 
            return author_name(username)
 
        return None
 

	
 
    def get_user_email(self, config_file=None):
 
        """
 
        Returns user's email from global configuration file.
 

	
 
        :param config_file: A path to file which should be used to retrieve
 
          configuration from (might also be a list of file paths)
 
        """
 
        username = self.get_config_value('ui', 'username')
 
        if username:
 
            return author_email(username)
 
        return None
kallithea/lib/vcs/nodes.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
"""
 
    vcs.nodes
 
    ~~~~~~~~~
 

	
 
    Module holding everything related to vcs nodes.
 

	
 
    :created_on: Apr 8, 2010
 
    :copyright: (c) 2010-2011 by Marcin Kuzminski, Lukasz Balcerzak.
 
"""
 

	
 
import stat
 
import posixpath
 
import mimetypes
 

	
 
from kallithea.lib.vcs.backends.base import EmptyChangeset
 
from kallithea.lib.vcs.exceptions import NodeError, RemovedFileNodeError
 
from kallithea.lib.vcs.utils.lazy import LazyProperty
 
from kallithea.lib.vcs.utils import safe_unicode, safe_str
 

	
 

	
 
class NodeKind:
 
    SUBMODULE = -1
 
    DIR = 1
 
    FILE = 2
 

	
 

	
 
class NodeState:
 
    ADDED = u'added'
 
    CHANGED = u'changed'
 
    NOT_CHANGED = u'not changed'
 
    REMOVED = u'removed'
 

	
 

	
 
class NodeGeneratorBase(object):
 
    """
 
    Base class for removed added and changed filenodes, it's a lazy generator
 
    class that will create filenodes only on iteration or call
 

	
 
    The len method doesn't need to create filenodes at all
 
    """
 

	
 
    def __init__(self, current_paths, cs):
 
        self.cs = cs
 
        self.current_paths = current_paths
 

	
 
    def __call__(self):
 
        return [n for n in self]
 

	
 
    def __getslice__(self, i, j):
 
        for p in self.current_paths[i:j]:
 
            yield self.cs.get_node(p)
 

	
 
    def __len__(self):
 
        return len(self.current_paths)
 

	
 
    def __iter__(self):
 
        for p in self.current_paths:
 
            yield self.cs.get_node(p)
 

	
 

	
 
class AddedFileNodesGenerator(NodeGeneratorBase):
 
    """
 
    Class holding Added files for current changeset
 
    """
 
    pass
 

	
 

	
 
class ChangedFileNodesGenerator(NodeGeneratorBase):
 
    """
 
    Class holding Changed files for current changeset
 
    """
 
    pass
 

	
 

	
 
class RemovedFileNodesGenerator(NodeGeneratorBase):
 
    """
 
    Class holding removed files for current changeset
 
    """
 
    def __iter__(self):
 
        for p in self.current_paths:
 
            yield RemovedFileNode(path=p)
 

	
 
    def __getslice__(self, i, j):
 
        for p in self.current_paths[i:j]:
 
            yield RemovedFileNode(path=p)
 

	
 

	
 
class Node(object):
 
    """
 
    Simplest class representing file or directory on repository.  SCM backends
 
    should use ``FileNode`` and ``DirNode`` subclasses rather than ``Node``
 
    directly.
 

	
 
    Node's ``path`` cannot start with slash as we operate on *relative* paths
 
    only. Moreover, every single node is identified by the ``path`` attribute,
 
    so it cannot end with slash, too. Otherwise, path could lead to mistakes.
 
    """
 

	
 
    def __init__(self, path, kind):
 
        if path.startswith('/'):
 
            raise NodeError("Cannot initialize Node objects with slash at "
 
                            "the beginning as only relative paths are supported")
 
        self.path = safe_str(path.rstrip('/'))  # we store paths as str
 
        if path == '' and kind != NodeKind.DIR:
 
            raise NodeError("Only DirNode and its subclasses may be "
 
                            "initialized with empty path")
 
        self.kind = kind
 
        #self.dirs, self.files = [], []
 
        if self.is_root() and not self.is_dir():
 
            raise NodeError("Root node cannot be FILE kind")
 

	
 
    @LazyProperty
 
    def parent(self):
 
        parent_path = self.get_parent_path()
 
        if parent_path:
 
            if self.changeset:
 
                return self.changeset.get_node(parent_path)
 
            return DirNode(parent_path)
 
        return None
 

	
 
    @LazyProperty
 
    def unicode_path(self):
 
        return safe_unicode(self.path)
 

	
 
    @LazyProperty
 
    def name(self):
 
        """
 
        Returns name of the node so if its path
 
        then only last part is returned.
 
        """
 
        return safe_unicode(self.path.rstrip('/').split('/')[-1])
 

	
 
    def _get_kind(self):
 
        return self._kind
 

	
 
    def _set_kind(self, kind):
 
        if hasattr(self, '_kind'):
 
            raise NodeError("Cannot change node's kind")
 
        else:
 
            self._kind = kind
 
            # Post setter check (path's trailing slash)
 
            if self.path.endswith('/'):
 
                raise NodeError("Node's path cannot end with slash")
 

	
 
    kind = property(_get_kind, _set_kind)
 

	
 
    def __cmp__(self, other):
 
        """
 
        Comparator using name of the node, needed for quick list sorting.
 
        """
 
        kind_cmp = cmp(self.kind, other.kind)
 
        if kind_cmp:
 
            return kind_cmp
 
        return cmp(self.name, other.name)
 

	
 
    def __eq__(self, other):
 
        for attr in ['name', 'path', 'kind']:
 
            if getattr(self, attr) != getattr(other, attr):
 
                return False
 
        if self.is_file():
 
            if self.content != other.content:
 
                return False
 
        else:
 
            # For DirNode's check without entering each dir
 
            self_nodes_paths = list(sorted(n.path for n in self.nodes))
 
            other_nodes_paths = list(sorted(n.path for n in self.nodes))
 
            if self_nodes_paths != other_nodes_paths:
 
                return False
 
        return True
 

	
 
    def __nq__(self, other):
 
        return not self.__eq__(other)
 

	
 
    def __repr__(self):
 
        return '<%s %r>' % (self.__class__.__name__, self.path)
 

	
 
    def __str__(self):
 
        return self.__repr__()
 

	
 
    def __unicode__(self):
 
        return self.name
 

	
 
    def get_parent_path(self):
 
        """
 
        Returns node's parent path or empty string if node is root.
 
        """
 
        if self.is_root():
 
            return ''
 
        return posixpath.dirname(self.path.rstrip('/')) + '/'
 

	
 
    def is_file(self):
 
        """
 
        Returns ``True`` if node's kind is ``NodeKind.FILE``, ``False``
 
        otherwise.
 
        """
 
        return self.kind == NodeKind.FILE
 

	
 
    def is_dir(self):
 
        """
 
        Returns ``True`` if node's kind is ``NodeKind.DIR``, ``False``
 
        otherwise.
 
        """
 
        return self.kind == NodeKind.DIR
 

	
 
    def is_root(self):
 
        """
 
        Returns ``True`` if node is a root node and ``False`` otherwise.
 
        """
 
        return self.kind == NodeKind.DIR and self.path == ''
 

	
 
    def is_submodule(self):
 
        """
 
        Returns ``True`` if node's kind is ``NodeKind.SUBMODULE``, ``False``
 
        otherwise.
 
        """
 
        return self.kind == NodeKind.SUBMODULE
 

	
 
    @LazyProperty
 
    def added(self):
 
        return self.state is NodeState.ADDED
 

	
 
    @LazyProperty
 
    def changed(self):
 
        return self.state is NodeState.CHANGED
 

	
 
    @LazyProperty
 
    def not_changed(self):
 
        return self.state is NodeState.NOT_CHANGED
 

	
 
    @LazyProperty
 
    def removed(self):
 
        return self.state is NodeState.REMOVED
 

	
 

	
 
class FileNode(Node):
 
    """
 
    Class representing file nodes.
 

	
 
    :attribute: path: path to the node, relative to repostiory's root
 
    :attribute: content: if given arbitrary sets content of the file
 
    :attribute: changeset: if given, first time content is accessed, callback
 
    :attribute: mode: octal stat mode for a node. Default is 0100644.
 
    """
 

	
 
    def __init__(self, path, content=None, changeset=None, mode=None):
 
        """
 
        Only one of ``content`` and ``changeset`` may be given. Passing both
 
        would raise ``NodeError`` exception.
 

	
 
        :param path: relative path to the node
 
        :param content: content may be passed to constructor
 
        :param changeset: if given, will use it to lazily fetch content
 
        :param mode: octal representation of ST_MODE (i.e. 0100644)
 
        """
 

	
 
        if content and changeset:
 
            raise NodeError("Cannot use both content and changeset")
 
        super(FileNode, self).__init__(path, kind=NodeKind.FILE)
 
        self.changeset = changeset
 
        self._content = content
 
        self._mode = mode or 0100644
 

	
 
    @LazyProperty
 
    def mode(self):
 
        """
 
        Returns lazily mode of the FileNode. If ``changeset`` is not set, would
 
        use value given at initialization or 0100644 (default).
 
        """
 
        if self.changeset:
 
            mode = self.changeset.get_file_mode(self.path)
 
        else:
 
            mode = self._mode
 
        return mode
 

	
 
    def _get_content(self):
 
        if self.changeset:
 
            content = self.changeset.get_file_content(self.path)
 
        else:
 
            content = self._content
 
        return content
 

	
 
    @property
 
    def content(self):
 
        """
 
        Returns lazily content of the FileNode. If possible, would try to
 
        decode content from UTF-8.
 
        """
 
        content = self._get_content()
 

	
 
        if bool(content and '\0' in content):
 
            return content
 
        return safe_unicode(content)
 

	
 
    @LazyProperty
 
    def size(self):
 
        if self.changeset:
 
            return self.changeset.get_file_size(self.path)
 
        raise NodeError("Cannot retrieve size of the file without related "
 
            "changeset attribute")
 

	
 
    @LazyProperty
 
    def message(self):
 
        if self.changeset:
 
            return self.last_changeset.message
 
        raise NodeError("Cannot retrieve message of the file without related "
 
            "changeset attribute")
 

	
 
    @LazyProperty
 
    def last_changeset(self):
 
        if self.changeset:
 
            return self.changeset.get_file_changeset(self.path)
 
        raise NodeError("Cannot retrieve last changeset of the file without "
 
            "related changeset attribute")
 

	
 
    def get_mimetype(self):
 
        """
 
        Mimetype is calculated based on the file's content. If ``_mimetype``
 
        attribute is available, it will be returned (backends which store
 
        mimetypes or can easily recognize them, should set this private
 
        attribute to indicate that type should *NOT* be calculated).
 
        """
 
        if hasattr(self, '_mimetype'):
 
            if (isinstance(self._mimetype, (tuple, list,)) and
 
                len(self._mimetype) == 2):
 
                return self._mimetype
 
            else:
 
                raise NodeError('given _mimetype attribute must be an 2 '
 
                                'element list or tuple')
 

	
 
        mtype, encoding = mimetypes.guess_type(self.name)
 

	
 
        if mtype is None:
 
            if self.is_binary:
 
                mtype = 'application/octet-stream'
 
                encoding = None
 
            else:
 
                mtype = 'text/plain'
 
                encoding = None
 

	
 
                #try with pygments
 
                try:
 
                    from pygments.lexers import get_lexer_for_filename
 
                    mt = get_lexer_for_filename(self.name).mimetypes
 
                except Exception:
 
                    from pygments import lexers
 
                    mt = lexers.get_lexer_for_filename(self.name).mimetypes
 
                except lexers.ClassNotFound:
 
                    mt = None
 

	
 
                if mt:
 
                    mtype = mt[0]
 

	
 
        return mtype, encoding
 

	
 
    @LazyProperty
 
    def mimetype(self):
 
        """
 
        Wrapper around full mimetype info. It returns only type of fetched
 
        mimetype without the encoding part. use get_mimetype function to fetch
 
        full set of (type,encoding)
 
        """
 
        return self.get_mimetype()[0]
 

	
 
    @LazyProperty
 
    def mimetype_main(self):
 
        return self.mimetype.split('/')[0]
 

	
 
    @LazyProperty
 
    def lexer(self):
 
        """
 
        Returns pygment's lexer class. Would try to guess lexer taking file's
 
        content, name and mimetype.
 
        """
 
        from pygments import lexers
 
        try:
 
            lexer = lexers.guess_lexer_for_filename(self.name, self.content, stripnl=False)
 
        except lexers.ClassNotFound:
 
            lexer = lexers.TextLexer(stripnl=False)
 
        # returns first alias
 
        return lexer
 

	
 
    @LazyProperty
 
    def lexer_alias(self):
 
        """
 
        Returns first alias of the lexer guessed for this file.
 
        """
 
        return self.lexer.aliases[0]
 

	
 
    @LazyProperty
 
    def history(self):
 
        """
 
        Returns a list of changeset for this file in which the file was changed
 
        """
 
        if self.changeset is None:
 
            raise NodeError('Unable to get changeset for this FileNode')
 
        return self.changeset.get_file_history(self.path)
 

	
 
    @LazyProperty
 
    def annotate(self):
 
        """
 
        Returns a list of three element tuples with lineno,changeset and line
 
        """
 
        if self.changeset is None:
 
            raise NodeError('Unable to get changeset for this FileNode')
 
        return self.changeset.get_file_annotate(self.path)
 

	
 
    @LazyProperty
 
    def state(self):
 
        if not self.changeset:
 
            raise NodeError("Cannot check state of the node if it's not "
 
                "linked with changeset")
 
        elif self.path in (node.path for node in self.changeset.added):
 
            return NodeState.ADDED
 
        elif self.path in (node.path for node in self.changeset.changed):
 
            return NodeState.CHANGED
 
        else:
 
            return NodeState.NOT_CHANGED
 

	
 
    @property
 
    def is_binary(self):
 
        """
 
        Returns True if file has binary content.
 
        """
 
        _bin = '\0' in self._get_content()
 
        return _bin
 

	
 
    def is_browser_compatible_image(self):
 
        return self.mimetype in [
 
            "image/gif",
 
            "image/jpeg",
 
            "image/png",
 
            "image/bmp"
 
        ]
 

	
 
    @LazyProperty
 
    def extension(self):
 
        """Returns filenode extension"""
 
        return self.name.split('.')[-1]
 

	
 
    @property
 
    def is_executable(self):
 
        """
 
        Returns ``True`` if file has executable flag turned on.
 
        """
 
        return bool(self.mode & stat.S_IXUSR)
 

	
 
    def __repr__(self):
 
        return '<%s %r @ %s>' % (self.__class__.__name__, self.path,
 
                                 getattr(self.changeset, 'short_id', ''))
 

	
 

	
 
class RemovedFileNode(FileNode):
 
    """
 
    Dummy FileNode class - trying to access any public attribute except path,
 
    name, kind or state (or methods/attributes checking those two) would raise
 
    RemovedFileNodeError.
 
    """
 
    ALLOWED_ATTRIBUTES = [
 
        'name', 'path', 'state', 'is_root', 'is_file', 'is_dir', 'kind',
 
        'added', 'changed', 'not_changed', 'removed'
 
    ]
 

	
 
    def __init__(self, path):
 
        """
 
        :param path: relative path to the node
 
        """
 
        super(RemovedFileNode, self).__init__(path=path)
 

	
 
    def __getattribute__(self, attr):
 
        if attr.startswith('_') or attr in RemovedFileNode.ALLOWED_ATTRIBUTES:
 
            return super(RemovedFileNode, self).__getattribute__(attr)
 
        raise RemovedFileNodeError("Cannot access attribute %s on "
 
            "RemovedFileNode" % attr)
 

	
 
    @LazyProperty
 
    def state(self):
 
        return NodeState.REMOVED
 

	
 

	
 
class DirNode(Node):
 
    """
 
    DirNode stores list of files and directories within this node.
 
    Nodes may be used standalone but within repository context they
 
    lazily fetch data within same repositorty's changeset.
 
    """
 

	
 
    def __init__(self, path, nodes=(), changeset=None):
 
        """
 
        Only one of ``nodes`` and ``changeset`` may be given. Passing both
 
        would raise ``NodeError`` exception.
 

	
 
        :param path: relative path to the node
 
        :param nodes: content may be passed to constructor
 
        :param changeset: if given, will use it to lazily fetch content
 
        :param size: always 0 for ``DirNode``
 
        """
 
        if nodes and changeset:
 
            raise NodeError("Cannot use both nodes and changeset")
 
        super(DirNode, self).__init__(path, NodeKind.DIR)
 
        self.changeset = changeset
 
        self._nodes = nodes
 

	
 
    @LazyProperty
 
    def content(self):
 
        raise NodeError("%s represents a dir and has no ``content`` attribute"
 
            % self)
 

	
 
    @LazyProperty
 
    def nodes(self):
 
        if self.changeset:
 
            nodes = self.changeset.get_nodes(self.path)
 
        else:
 
            nodes = self._nodes
 
        self._nodes_dict = dict((node.path, node) for node in nodes)
 
        return sorted(nodes)
 

	
 
    @LazyProperty
 
    def files(self):
 
        return sorted((node for node in self.nodes if node.is_file()))
 

	
 
    @LazyProperty
 
    def dirs(self):
 
        return sorted((node for node in self.nodes if node.is_dir()))
 

	
 
    def __iter__(self):
 
        for node in self.nodes:
 
            yield node
 

	
 
    def get_node(self, path):
 
        """
 
        Returns node from within this particular ``DirNode``, so it is now
 
        allowed to fetch, i.e. node located at 'docs/api/index.rst' from node
 
        'docs'. In order to access deeper nodes one must fetch nodes between
 
        them first - this would work::
 

	
 
           docs = root.get_node('docs')
 
           docs.get_node('api').get_node('index.rst')
 

	
 
        :param: path - relative to the current node
 

	
 
        .. note::
 
           To access lazily (as in example above) node have to be initialized
 
           with related changeset object - without it node is out of
 
           context and may know nothing about anything else than nearest
 
           (located at same level) nodes.
 
        """
 
        try:
 
            path = path.rstrip('/')
 
            if path == '':
 
                raise NodeError("Cannot retrieve node without path")
 
            self.nodes  # access nodes first in order to set _nodes_dict
 
            paths = path.split('/')
 
            if len(paths) == 1:
 
                if not self.is_root():
 
                    path = '/'.join((self.path, paths[0]))
 
                else:
 
                    path = paths[0]
 
                return self._nodes_dict[path]
 
            elif len(paths) > 1:
 
                if self.changeset is None:
 
                    raise NodeError("Cannot access deeper "
 
                                    "nodes without changeset")
 
                else:
 
                    path1, path2 = paths[0], '/'.join(paths[1:])
 
                    return self.get_node(path1).get_node(path2)
 
            else:
 
                raise KeyError
 
        except KeyError:
 
            raise NodeError("Node does not exist at %s" % path)
 

	
 
    @LazyProperty
 
    def state(self):
 
        raise NodeError("Cannot access state of DirNode")
 

	
 
    @LazyProperty
 
    def size(self):
 
        size = 0
 
        for root, dirs, files in self.changeset.walk(self.path):
 
            for f in files:
 
                size += f.size
 

	
 
        return size
 

	
 
    def __repr__(self):
 
        return '<%s %r @ %s>' % (self.__class__.__name__, self.path,
 
                                 getattr(self.changeset, 'short_id', ''))
 

	
 

	
 
class RootNode(DirNode):
 
    """
 
    DirNode being the root node of the repository.
 
    """
 

	
 
    def __init__(self, nodes=(), changeset=None):
 
        super(RootNode, self).__init__(path='', nodes=nodes,
 
            changeset=changeset)
 

	
 
    def __repr__(self):
 
        return '<%s>' % self.__class__.__name__
 

	
 

	
 
class SubModuleNode(Node):
 
    """
 
    represents a SubModule of Git or SubRepo of Mercurial
 
    """
 
    is_binary = False
 
    size = 0
 

	
 
    def __init__(self, name, url=None, changeset=None, alias=None):
 
        self.path = name
 
        self.kind = NodeKind.SUBMODULE
 
        self.alias = alias
 
        # we have to use emptyChangeset here since this can point to svn/git/hg
 
        # submodules we cannot get from repository
 
        self.changeset = EmptyChangeset(str(changeset), alias=alias)
 
        self.url = url or self._extract_submodule_url()
 

	
 
    def __repr__(self):
 
        return '<%s %r @ %s>' % (self.__class__.__name__, self.path,
 
                                 getattr(self.changeset, 'short_id', ''))
 

	
 
    def _extract_submodule_url(self):
 
        if self.alias == 'git':
 
            #TODO: find a way to parse gits submodule file and extract the
 
            # linking URL
 
            return self.path
 
        if self.alias == 'hg':
 
            return self.path
 

	
 
    @LazyProperty
 
    def name(self):
 
        """
 
        Returns name of the node so if its path
 
        then only last part is returned.
 
        """
 
        org = safe_unicode(self.path.rstrip('/').split('/')[-1])
 
        return u'%s @ %s' % (org, self.changeset.short_id)
kallithea/lib/vcs/subprocessio.py
Show inline comments
 
"""
 
Module provides a class allowing to wrap communication over subprocess.Popen
 
input, output, error streams into a meaningfull, non-blocking, concurrent
 
stream processor exposing the output data as an iterator fitting to be a
 
return value passed by a WSGI applicaiton to a WSGI server per PEP 3333.
 

	
 
Copyright (c) 2011  Daniel Dotsenko <dotsa[at]hotmail.com>
 

	
 
This file is part of git_http_backend.py Project.
 

	
 
git_http_backend.py Project is free software: you can redistribute it and/or
 
modify it under the terms of the GNU Lesser General Public License as
 
published by the Free Software Foundation, either version 2.1 of the License,
 
or (at your option) any later version.
 

	
 
git_http_backend.py Project is distributed in the hope that it will be useful,
 
but WITHOUT ANY WARRANTY; without even the implied warranty of
 
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 
GNU Lesser General Public License for more details.
 

	
 
You should have received a copy of the GNU Lesser General Public License
 
along with git_http_backend.py Project.
 
If not, see <http://www.gnu.org/licenses/>.
 
"""
 
import os
 
import subprocess
 
from kallithea.lib.vcs.utils.compat import deque, Event, Thread, _bytes, _bytearray
 

	
 

	
 
class StreamFeeder(Thread):
 
    """
 
    Normal writing into pipe-like is blocking once the buffer is filled.
 
    This thread allows a thread to seep data from a file-like into a pipe
 
    without blocking the main thread.
 
    We close inpipe once the end of the source stream is reached.
 
    """
 

	
 
    def __init__(self, source):
 
        super(StreamFeeder, self).__init__()
 
        self.daemon = True
 
        filelike = False
 
        self.bytes = _bytes()
 
        if type(source) in (type(''), _bytes, _bytearray):  # string-like
 
            self.bytes = _bytes(source)
 
        else:  # can be either file pointer or file-like
 
            if type(source) in (int, long):  # file pointer it is
 
                ## converting file descriptor (int) stdin into file-like
 
                try:
 
                    source = os.fdopen(source, 'rb', 16384)
 
                except Exception:
 
                    pass
 
                source = os.fdopen(source, 'rb', 16384)
 
            # let's see if source is file-like by now
 
            try:
 
                filelike = source.read
 
            except Exception:
 
                pass
 
            filelike = hasattr(source, 'read')
 
        if not filelike and not self.bytes:
 
            raise TypeError("StreamFeeder's source object must be a readable "
 
                            "file-like, a file descriptor, or a string-like.")
 
        self.source = source
 
        self.readiface, self.writeiface = os.pipe()
 

	
 
    def run(self):
 
        t = self.writeiface
 
        if self.bytes:
 
            os.write(t, self.bytes)
 
        else:
 
            s = self.source
 
            b = s.read(4096)
 
            while b:
 
                os.write(t, b)
 
                b = s.read(4096)
 
        os.close(t)
 

	
 
    @property
 
    def output(self):
 
        return self.readiface
 

	
 

	
 
class InputStreamChunker(Thread):
 
    def __init__(self, source, target, buffer_size, chunk_size):
 

	
 
        super(InputStreamChunker, self).__init__()
 

	
 
        self.daemon = True  # die die die.
 

	
 
        self.source = source
 
        self.target = target
 
        self.chunk_count_max = int(buffer_size / chunk_size) + 1
 
        self.chunk_size = chunk_size
 

	
 
        self.data_added = Event()
 
        self.data_added.clear()
 

	
 
        self.keep_reading = Event()
 
        self.keep_reading.set()
 

	
 
        self.EOF = Event()
 
        self.EOF.clear()
 

	
 
        self.go = Event()
 
        self.go.set()
 

	
 
    def stop(self):
 
        self.go.clear()
 
        self.EOF.set()
 
        try:
 
            # this is not proper, but is done to force the reader thread let
 
            # go of the input because, if successful, .close() will send EOF
 
            # down the pipe.
 
            self.source.close()
 
        except:
 
            pass
 

	
 
    def run(self):
 
        s = self.source
 
        t = self.target
 
        cs = self.chunk_size
 
        ccm = self.chunk_count_max
 
        kr = self.keep_reading
 
        da = self.data_added
 
        go = self.go
 

	
 
        try:
 
            b = s.read(cs)
 
        except ValueError:
 
            b = ''
 

	
 
        while b and go.is_set():
 
            if len(t) > ccm:
 
                kr.clear()
 
                kr.wait(2)
 
                # # this only works on 2.7.x and up
 
                # if not kr.wait(10):
 
                #     raise Exception("Timed out while waiting for input to be read.")
 
                # instead we'll use this
 
                if len(t) > ccm + 3:
 
                    raise IOError(
 
                        "Timed out while waiting for input from subprocess.")
 
            t.append(b)
 
            da.set()
 
            b = s.read(cs)
 
        self.EOF.set()
 
        da.set()  # for cases when done but there was no input.
 

	
 

	
 
class BufferedGenerator(object):
 
    """
 
    Class behaves as a non-blocking, buffered pipe reader.
 
    Reads chunks of data (through a thread)
 
    from a blocking pipe, and attaches these to an array (Deque) of chunks.
 
    Reading is halted in the thread when max chunks is internally buffered.
 
    The .next() may operate in blocking or non-blocking fashion by yielding
 
    '' if no data is ready
 
    to be sent or by not returning until there is some data to send
 
    When we get EOF from underlying source pipe we raise the marker to raise
 
    StopIteration after the last chunk of data is yielded.
 
    """
 

	
 
    def __init__(self, source, buffer_size=65536, chunk_size=4096,
 
                 starting_values=[], bottomless=False):
 

	
 
        if bottomless:
 
            maxlen = int(buffer_size / chunk_size)
 
        else:
 
            maxlen = None
 

	
 
        self.data = deque(starting_values, maxlen)
 
        self.worker = InputStreamChunker(source, self.data, buffer_size,
 
                                         chunk_size)
 
        if starting_values:
 
            self.worker.data_added.set()
 
        self.worker.start()
 

	
 
    ####################
 
    # Generator's methods
 
    ####################
 

	
 
    def __iter__(self):
 
        return self
 

	
 
    def next(self):
 
        while not len(self.data) and not self.worker.EOF.is_set():
 
            self.worker.data_added.clear()
 
            self.worker.data_added.wait(0.2)
 
        if len(self.data):
 
            self.worker.keep_reading.set()
 
            return _bytes(self.data.popleft())
 
        elif self.worker.EOF.is_set():
 
            raise StopIteration
 

	
 
    def throw(self, type, value=None, traceback=None):
 
        if not self.worker.EOF.is_set():
 
            raise type(value)
 

	
 
    def start(self):
 
        self.worker.start()
 

	
 
    def stop(self):
 
        self.worker.stop()
 

	
 
    def close(self):
 
        try:
 
            self.worker.stop()
 
            self.throw(GeneratorExit)
 
        except (GeneratorExit, StopIteration):
 
            pass
 

	
 
    def __del__(self):
 
        self.close()
 

	
 
    ####################
 
    # Threaded reader's infrastructure.
 
    ####################
 
    @property
 
    def input(self):
 
        return self.worker.w
 

	
 
    @property
 
    def data_added_event(self):
 
        return self.worker.data_added
 

	
 
    @property
 
    def data_added(self):
 
        return self.worker.data_added.is_set()
 

	
 
    @property
 
    def reading_paused(self):
 
        return not self.worker.keep_reading.is_set()
 

	
 
    @property
 
    def done_reading_event(self):
 
        """
 
        Done_reding does not mean that the iterator's buffer is empty.
 
        Iterator might have done reading from underlying source, but the read
 
        chunks might still be available for serving through .next() method.
 

	
 
        :returns: An Event class instance.
 
        """
 
        return self.worker.EOF
 

	
 
    @property
 
    def done_reading(self):
 
        """
 
        Done_reding does not mean that the iterator's buffer is empty.
 
        Iterator might have done reading from underlying source, but the read
 
        chunks might still be available for serving through .next() method.
 

	
 
        :returns: An Bool value.
 
        """
 
        return self.worker.EOF.is_set()
 

	
 
    @property
 
    def length(self):
 
        """
 
        returns int.
 

	
 
        This is the lenght of the que of chunks, not the length of
 
        the combined contents in those chunks.
 

	
 
        __len__() cannot be meaningfully implemented because this
 
        reader is just flying throuh a bottomless pit content and
 
        can only know the lenght of what it already saw.
 

	
 
        If __len__() on WSGI server per PEP 3333 returns a value,
 
        the responce's length will be set to that. In order not to
 
        confuse WSGI PEP3333 servers, we will not implement __len__
 
        at all.
 
        """
 
        return len(self.data)
 

	
 
    def prepend(self, x):
 
        self.data.appendleft(x)
 

	
 
    def append(self, x):
 
        self.data.append(x)
 

	
 
    def extend(self, o):
 
        self.data.extend(o)
 

	
 
    def __getitem__(self, i):
 
        return self.data[i]
 

	
 

	
 
class SubprocessIOChunker(object):
 
    """
 
    Processor class wrapping handling of subprocess IO.
 

	
 
    In a way, this is a "communicate()" replacement with a twist.
 

	
 
    - We are multithreaded. Writing in and reading out, err are all sep threads.
 
    - We support concurrent (in and out) stream processing.
 
    - The output is not a stream. It's a queue of read string (bytes, not unicode)
 
      chunks. The object behaves as an iterable. You can "for chunk in obj:" us.
 
    - We are non-blocking in more respects than communicate()
 
      (reading from subprocess out pauses when internal buffer is full, but
 
       does not block the parent calling code. On the flip side, reading from
 
       slow-yielding subprocess may block the iteration until data shows up. This
 
       does not block the parallel inpipe reading occurring parallel thread.)
 

	
 
    The purpose of the object is to allow us to wrap subprocess interactions into
 
    and interable that can be passed to a WSGI server as the application's return
 
    value. Because of stream-processing-ability, WSGI does not have to read ALL
 
    of the subprocess's output and buffer it, before handing it to WSGI server for
 
    HTTP response. Instead, the class initializer reads just a bit of the stream
 
    to figure out if error ocurred or likely to occur and if not, just hands the
 
    further iteration over subprocess output to the server for completion of HTTP
 
    response.
 

	
 
    The real or perceived subprocess error is trapped and raised as one of
 
    EnvironmentError family of exceptions
 

	
 
    Example usage:
 
    #    try:
 
    #        answer = SubprocessIOChunker(
 
    #            cmd,
 
    #            input,
 
    #            buffer_size = 65536,
 
    #            chunk_size = 4096
 
    #            )
 
    #    except (EnvironmentError) as e:
 
    #        print str(e)
 
    #        raise e
 
    #
 
    #    return answer
 

	
 

	
 
    """
 

	
 
    def __init__(self, cmd, inputstream=None, buffer_size=65536,
 
                 chunk_size=4096, starting_values=[], **kwargs):
 
        """
 
        Initializes SubprocessIOChunker
 

	
 
        :param cmd: A Subprocess.Popen style "cmd". Can be string or array of strings
 
        :param inputstream: (Default: None) A file-like, string, or file pointer.
 
        :param buffer_size: (Default: 65536) A size of total buffer per stream in bytes.
 
        :param chunk_size: (Default: 4096) A max size of a chunk. Actual chunk may be smaller.
 
        :param starting_values: (Default: []) An array of strings to put in front of output que.
 
        """
 

	
 
        if inputstream:
 
            input_streamer = StreamFeeder(inputstream)
 
            input_streamer.start()
 
            inputstream = input_streamer.output
 

	
 
        _shell = kwargs.get('shell', True)
 
        if isinstance(cmd, (list, tuple)):
 
            cmd = ' '.join(cmd)
 

	
 
        kwargs['shell'] = _shell
 
        _p = subprocess.Popen(cmd, bufsize=-1,
 
                              stdin=inputstream,
 
                              stdout=subprocess.PIPE,
 
                              stderr=subprocess.PIPE,
 
                              **kwargs)
 

	
 
        bg_out = BufferedGenerator(_p.stdout, buffer_size, chunk_size,
 
                                   starting_values)
 
        bg_err = BufferedGenerator(_p.stderr, 16000, 1, bottomless=True)
 

	
 
        while not bg_out.done_reading and not bg_out.reading_paused and not bg_err.length:
 
            # doing this until we reach either end of file, or end of buffer.
 
            bg_out.data_added_event.wait(1)
 
            bg_out.data_added_event.clear()
 

	
 
        # at this point it's still ambiguous if we are done reading or just full buffer.
 
        # Either way, if error (returned by ended process, or implied based on
 
        # presence of stuff in stderr output) we error out.
 
        # Else, we are happy.
 
        _returncode = _p.poll()
 
        if _returncode or (_returncode is None and bg_err.length):
 
            try:
 
                _p.terminate()
 
            except Exception:
 
                pass
 
            bg_out.stop()
 
            out = ''.join(bg_out)
 
            bg_err.stop()
 
            err = ''.join(bg_err)
 
            if (err.strip() == 'fatal: The remote end hung up unexpectedly' and
 
                out.startswith('0034shallow ')):
 
                # hack inspired by https://github.com/schacon/grack/pull/7
 
                bg_out = iter([out])
 
                _p = None
 
            elif err:
 
                raise EnvironmentError(
 
                    "Subprocess exited due to an error:\n" + err)
 
            else:
 
                raise EnvironmentError(
 
                    "Subprocess exited with non 0 ret code:%s" % _returncode)
 
        self.process = _p
 
        self.output = bg_out
 
        self.error = bg_err
 
        self.inputstream = inputstream
 

	
 
    def __iter__(self):
 
        return self
 

	
 
    def next(self):
 
        if self.process and self.process.poll():
 
            err = '%s' % ''.join(self.error)
 
            raise EnvironmentError("Subprocess exited due to an error:\n" + err)
 
        return self.output.next()
 

	
 
    def throw(self, type, value=None, traceback=None):
 
        if self.output.length or not self.output.done_reading:
 
            raise type(value)
 

	
 
    def close(self):
 
        try:
 
            self.process.terminate()
 
        except:
 
            pass
 
        try:
 
            self.output.close()
 
        except:
 
            pass
 
        try:
 
            self.error.close()
 
        except:
 
            pass
 
        try:
 
            os.close(self.inputstream)
 
        except:
 
            pass
 

	
 
    def __del__(self):
 
        self.close()
kallithea/model/api_key.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.model.api_key
 
~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
api key model for Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Sep 8, 2013
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
from __future__ import with_statement
 
import time
 
import logging
 
import traceback
 
from sqlalchemy import or_
 

	
 
from kallithea.lib.utils2 import generate_api_key
 
from kallithea.model import BaseModel
 
from kallithea.model.db import UserApiKeys
 
from kallithea.model.meta import Session
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class ApiKeyModel(BaseModel):
 
    cls = UserApiKeys
 

	
 
    def create(self, user, description, lifetime=-1):
 
        """
 
        :param user: user or user_id
 
        :param description: description of ApiKey
 
        :param lifetime: expiration time in seconds
 
        """
 
        user = self._get_user(user)
 

	
 
        new_api_key = UserApiKeys()
 
        new_api_key.api_key = generate_api_key(user.username)
 
        new_api_key.user_id = user.user_id
 
        new_api_key.description = description
 
        new_api_key.expires = time.time() + (lifetime * 60) if lifetime != -1 else -1
 
        Session().add(new_api_key)
 

	
 
        return new_api_key
 

	
 
    def delete(self, api_key, user=None):
 
        """
 
        Deletes given api_key, if user is set it also filters the object for
 
        deletion by given user.
 
        """
 
        api_key = UserApiKeys.query().filter(UserApiKeys.api_key == api_key)
 

	
 
        if user:
 
            user = self._get_user(user)
 
            api_key = api_key.filter(UserApiKeys.user_id == user.user_id)
 

	
 
        api_key = api_key.scalar()
 
        try:
 
            Session().delete(api_key)
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            raise
 
        Session().delete(api_key)
 

	
 
    def get_api_keys(self, user, show_expired=True):
 
        user = self._get_user(user)
 
        user_api_keys = UserApiKeys.query()\
 
            .filter(UserApiKeys.user_id == user.user_id)
 
        if not show_expired:
 
            user_api_keys = user_api_keys\
 
                .filter(or_(UserApiKeys.expires == -1,
 
                            UserApiKeys.expires >= time.time()))
 
        return user_api_keys
kallithea/model/db.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.model.db
 
~~~~~~~~~~~~~~~~~~
 

	
 
Database Models for Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 08, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 
import os
 
import time
 
import logging
 
import datetime
 
import traceback
 
import hashlib
 
import collections
 
import functools
 

	
 
from sqlalchemy import *
 
from sqlalchemy.ext.hybrid import hybrid_property
 
from sqlalchemy.orm import relationship, joinedload, class_mapper, validates
 
from beaker.cache import cache_region, region_invalidate
 
from webob.exc import HTTPNotFound
 

	
 
from pylons.i18n.translation import lazy_ugettext as _
 

	
 
from kallithea import DB_PREFIX
 
from kallithea.lib.vcs import get_backend
 
from kallithea.lib.vcs.utils.helpers import get_scm
 
from kallithea.lib.vcs.exceptions import VCSError
 
from kallithea.lib.vcs.utils.lazy import LazyProperty
 
from kallithea.lib.vcs.backends.base import EmptyChangeset
 

	
 
from kallithea.lib.utils2 import str2bool, safe_str, get_changeset_safe, \
 
    safe_unicode, remove_prefix, time_to_datetime, aslist, Optional, safe_int, \
 
    get_clone_url, urlreadable
 
from kallithea.lib.compat import json
 
from kallithea.lib.caching_query import FromCache
 

	
 
from kallithea.model.meta import Base, Session
 

	
 
URL_SEP = '/'
 
log = logging.getLogger(__name__)
 

	
 
#==============================================================================
 
# BASE CLASSES
 
#==============================================================================
 

	
 
_hash_key = lambda k: hashlib.md5(safe_str(k)).hexdigest()
 

	
 

	
 
class BaseModel(object):
 
    """
 
    Base Model for all classess
 
    """
 

	
 
    @classmethod
 
    def _get_keys(cls):
 
        """return column names for this model """
 
        return class_mapper(cls).c.keys()
 

	
 
    def get_dict(self):
 
        """
 
        return dict with keys and values corresponding
 
        to this model data """
 

	
 
        d = {}
 
        for k in self._get_keys():
 
            d[k] = getattr(self, k)
 

	
 
        # also use __json__() if present to get additional fields
 
        _json_attr = getattr(self, '__json__', None)
 
        if _json_attr:
 
            # update with attributes from __json__
 
            if callable(_json_attr):
 
                _json_attr = _json_attr()
 
            for k, val in _json_attr.iteritems():
 
                d[k] = val
 
        return d
 

	
 
    def get_appstruct(self):
 
        """return list with keys and values tupples corresponding
 
        to this model data """
 

	
 
        l = []
 
        for k in self._get_keys():
 
            l.append((k, getattr(self, k),))
 
        return l
 

	
 
    def populate_obj(self, populate_dict):
 
        """populate model with data from given populate_dict"""
 

	
 
        for k in self._get_keys():
 
            if k in populate_dict:
 
                setattr(self, k, populate_dict[k])
 

	
 
    @classmethod
 
    def query(cls):
 
        return Session().query(cls)
 

	
 
    @classmethod
 
    def get(cls, id_):
 
        if id_:
 
            return cls.query().get(id_)
 

	
 
    @classmethod
 
    def get_or_404(cls, id_):
 
        try:
 
            id_ = int(id_)
 
        except (TypeError, ValueError):
 
            raise HTTPNotFound
 

	
 
        res = cls.query().get(id_)
 
        if not res:
 
            raise HTTPNotFound
 
        return res
 

	
 
    @classmethod
 
    def getAll(cls):
 
        # deprecated and left for backward compatibility
 
        return cls.get_all()
 

	
 
    @classmethod
 
    def get_all(cls):
 
        return cls.query().all()
 

	
 
    @classmethod
 
    def delete(cls, id_):
 
        obj = cls.query().get(id_)
 
        Session().delete(obj)
 

	
 
    def __repr__(self):
 
        if hasattr(self, '__unicode__'):
 
            # python repr needs to return str
 
            try:
 
                return safe_str(self.__unicode__())
 
            except UnicodeDecodeError:
 
                pass
 
        return '<DB:%s>' % (self.__class__.__name__)
 

	
 

	
 
class Setting(Base, BaseModel):
 
    __tablename__ = DB_PREFIX + 'settings'
 

	
 
    __table_args__ = (
 
        UniqueConstraint('app_settings_name'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 

	
 
    SETTINGS_TYPES = {
 
        'str': safe_str,
 
        'int': safe_int,
 
        'unicode': safe_unicode,
 
        'bool': str2bool,
 
        'list': functools.partial(aslist, sep=',')
 
    }
 
    DEFAULT_UPDATE_URL = ''
 

	
 
    app_settings_id = Column("app_settings_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    app_settings_name = Column("app_settings_name", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    _app_settings_value = Column("app_settings_value", String(4096, convert_unicode=False), nullable=True, unique=None, default=None)
 
    _app_settings_type = Column("app_settings_type", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 

	
 
    def __init__(self, key='', val='', type='unicode'):
 
        self.app_settings_name = key
 
        self.app_settings_value = val
 
        self.app_settings_type = type
 

	
 
    @validates('_app_settings_value')
 
    def validate_settings_value(self, key, val):
 
        assert type(val) == unicode
 
        return val
 

	
 
    @hybrid_property
 
    def app_settings_value(self):
 
        v = self._app_settings_value
 
        _type = self.app_settings_type
 
        converter = self.SETTINGS_TYPES.get(_type) or self.SETTINGS_TYPES['unicode']
 
        return converter(v)
 

	
 
    @app_settings_value.setter
 
    def app_settings_value(self, val):
 
        """
 
        Setter that will always make sure we use unicode in app_settings_value
 

	
 
        :param val:
 
        """
 
        self._app_settings_value = safe_unicode(val)
 

	
 
    @hybrid_property
 
    def app_settings_type(self):
 
        return self._app_settings_type
 

	
 
    @app_settings_type.setter
 
    def app_settings_type(self, val):
 
        if val not in self.SETTINGS_TYPES:
 
            raise Exception('type must be one of %s got %s'
 
                            % (self.SETTINGS_TYPES.keys(), val))
 
        self._app_settings_type = val
 

	
 
    def __unicode__(self):
 
        return u"<%s('%s:%s[%s]')>" % (
 
            self.__class__.__name__,
 
            self.app_settings_name, self.app_settings_value, self.app_settings_type
 
        )
 

	
 
    @classmethod
 
    def get_by_name(cls, key):
 
        return cls.query()\
 
            .filter(cls.app_settings_name == key).scalar()
 

	
 
    @classmethod
 
    def get_by_name_or_create(cls, key, val='', type='unicode'):
 
        res = cls.get_by_name(key)
 
        if not res:
 
            res = cls(key, val, type)
 
        return res
 

	
 
    @classmethod
 
    def create_or_update(cls, key, val=Optional(''), type=Optional('unicode')):
 
        """
 
        Creates or updates Kallithea setting. If updates are triggered, it will only
 
        update parameters that are explicitly set. Optional instance will be skipped.
 

	
 
        :param key:
 
        :param val:
 
        :param type:
 
        :return:
 
        """
 
        res = cls.get_by_name(key)
 
        if not res:
 
            val = Optional.extract(val)
 
            type = Optional.extract(type)
 
            res = cls(key, val, type)
 
        else:
 
            res.app_settings_name = key
 
            if not isinstance(val, Optional):
 
                # update if set
 
                res.app_settings_value = val
 
            if not isinstance(type, Optional):
 
                # update if set
 
                res.app_settings_type = type
 
        return res
 

	
 
    @classmethod
 
    def get_app_settings(cls, cache=False):
 

	
 
        ret = cls.query()
 

	
 
        if cache:
 
            ret = ret.options(FromCache("sql_cache_short", "get_hg_settings"))
 

	
 
        if not ret:
 
            raise Exception('Could not get application settings !')
 
        settings = {}
 
        for each in ret:
 
            settings[each.app_settings_name] = \
 
                each.app_settings_value
 

	
 
        return settings
 

	
 
    @classmethod
 
    def get_auth_plugins(cls, cache=False):
 
        auth_plugins = cls.get_by_name("auth_plugins").app_settings_value
 
        return auth_plugins
 

	
 
    @classmethod
 
    def get_auth_settings(cls, cache=False):
 
        ret = cls.query()\
 
                .filter(cls.app_settings_name.startswith('auth_')).all()
 
        fd = {}
 
        for row in ret:
 
            fd[row.app_settings_name] = row.app_settings_value
 
        return fd
 

	
 
    @classmethod
 
    def get_default_repo_settings(cls, cache=False, strip_prefix=False):
 
        ret = cls.query()\
 
                .filter(cls.app_settings_name.startswith('default_')).all()
 
        fd = {}
 
        for row in ret:
 
            key = row.app_settings_name
 
            if strip_prefix:
 
                key = remove_prefix(key, prefix='default_')
 
            fd.update({key: row.app_settings_value})
 

	
 
        return fd
 

	
 
    @classmethod
 
    def get_server_info(cls):
 
        import pkg_resources
 
        import platform
 
        import kallithea
 
        from kallithea.lib.utils import check_git_version
 
        mods = [(p.project_name, p.version) for p in pkg_resources.working_set]
 
        info = {
 
            'modules': sorted(mods, key=lambda k: k[0].lower()),
 
            'py_version': platform.python_version(),
 
            'platform': safe_unicode(platform.platform()),
 
            'kallithea_version': kallithea.__version__,
 
            'git_version': safe_unicode(check_git_version()),
 
            'git_path': kallithea.CONFIG.get('git_path')
 
        }
 
        return info
 

	
 

	
 
class Ui(Base, BaseModel):
 
    __tablename__ = DB_PREFIX + 'ui'
 
    __table_args__ = (
 
        UniqueConstraint('ui_key'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 

	
 
    HOOK_UPDATE = 'changegroup.update'
 
    HOOK_REPO_SIZE = 'changegroup.repo_size'
 
    HOOK_PUSH = 'changegroup.push_logger'
 
    HOOK_PRE_PUSH = 'prechangegroup.pre_push'
 
    HOOK_PULL = 'outgoing.pull_logger'
 
    HOOK_PRE_PULL = 'preoutgoing.pre_pull'
 

	
 
    ui_id = Column("ui_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    ui_section = Column("ui_section", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    ui_key = Column("ui_key", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    ui_value = Column("ui_value", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    ui_active = Column("ui_active", Boolean(), nullable=True, unique=None, default=True)
 

	
 
    # def __init__(self, section='', key='', value=''):
 
    #     self.ui_section = section
 
    #     self.ui_key = key
 
    #     self.ui_value = value
 

	
 
    @classmethod
 
    def get_by_key(cls, key):
 
        return cls.query().filter(cls.ui_key == key).scalar()
 

	
 
    @classmethod
 
    def get_builtin_hooks(cls):
 
        q = cls.query()
 
        q = q.filter(cls.ui_key.in_([cls.HOOK_UPDATE, cls.HOOK_REPO_SIZE,
 
                                     cls.HOOK_PUSH, cls.HOOK_PRE_PUSH,
 
                                     cls.HOOK_PULL, cls.HOOK_PRE_PULL]))
 
        return q.all()
 

	
 
    @classmethod
 
    def get_custom_hooks(cls):
 
        q = cls.query()
 
        q = q.filter(~cls.ui_key.in_([cls.HOOK_UPDATE, cls.HOOK_REPO_SIZE,
 
                                      cls.HOOK_PUSH, cls.HOOK_PRE_PUSH,
 
                                      cls.HOOK_PULL, cls.HOOK_PRE_PULL]))
 
        q = q.filter(cls.ui_section == 'hooks')
 
        return q.all()
 

	
 
    @classmethod
 
    def get_repos_location(cls):
 
        return cls.get_by_key('/').ui_value
 

	
 
    @classmethod
 
    def create_or_update_hook(cls, key, val):
 
        new_ui = cls.get_by_key(key) or cls()
 
        new_ui.ui_section = 'hooks'
 
        new_ui.ui_active = True
 
        new_ui.ui_key = key
 
        new_ui.ui_value = val
 

	
 
        Session().add(new_ui)
 

	
 
    def __repr__(self):
 
        return '<%s[%s]%s=>%s]>' % (self.__class__.__name__, self.ui_section,
 
                                    self.ui_key, self.ui_value)
 

	
 

	
 
class User(Base, BaseModel):
 
    __tablename__ = 'users'
 
    __table_args__ = (
 
        UniqueConstraint('username'), UniqueConstraint('email'),
 
        Index('u_username_idx', 'username'),
 
        Index('u_email_idx', 'email'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    DEFAULT_USER = 'default'
 
    DEFAULT_GRAVATAR_URL = 'https://secure.gravatar.com/avatar/{md5email}?d=identicon&s={size}'
 

	
 
    user_id = Column("user_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    username = Column("username", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    password = Column("password", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    active = Column("active", Boolean(), nullable=True, unique=None, default=True)
 
    admin = Column("admin", Boolean(), nullable=True, unique=None, default=False)
 
    name = Column("firstname", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    lastname = Column("lastname", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    _email = Column("email", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    last_login = Column("last_login", DateTime(timezone=False), nullable=True, unique=None, default=None)
 
    extern_type = Column("extern_type", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    extern_name = Column("extern_name", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    api_key = Column("api_key", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    inherit_default_permissions = Column("inherit_default_permissions", Boolean(), nullable=False, unique=None, default=True)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 
    _user_data = Column("user_data", LargeBinary(), nullable=True)  # JSON data
 

	
 
    user_log = relationship('UserLog')
 
    user_perms = relationship('UserToPerm', primaryjoin="User.user_id==UserToPerm.user_id", cascade='all')
 

	
 
    repositories = relationship('Repository')
 
    user_followers = relationship('UserFollowing', primaryjoin='UserFollowing.follows_user_id==User.user_id', cascade='all')
 
    followings = relationship('UserFollowing', primaryjoin='UserFollowing.user_id==User.user_id', cascade='all')
 

	
 
    repo_to_perm = relationship('UserRepoToPerm', primaryjoin='UserRepoToPerm.user_id==User.user_id', cascade='all')
 
    repo_group_to_perm = relationship('UserRepoGroupToPerm', primaryjoin='UserRepoGroupToPerm.user_id==User.user_id', cascade='all')
 

	
 
    group_member = relationship('UserGroupMember', cascade='all')
 

	
 
    notifications = relationship('UserNotification', cascade='all')
 
    # notifications assigned to this user
 
    user_created_notifications = relationship('Notification', cascade='all')
 
    # comments created by this user
 
    user_comments = relationship('ChangesetComment', cascade='all')
 
    #extra emails for this user
 
    user_emails = relationship('UserEmailMap', cascade='all')
 
    #extra api keys
 
    user_api_keys = relationship('UserApiKeys', cascade='all')
 

	
 

	
 
    @hybrid_property
 
    def email(self):
 
        return self._email
 

	
 
    @email.setter
 
    def email(self, val):
 
        self._email = val.lower() if val else None
 

	
 
    @property
 
    def firstname(self):
 
        # alias for future
 
        return self.name
 

	
 
    @property
 
    def emails(self):
 
        other = UserEmailMap.query().filter(UserEmailMap.user==self).all()
 
        return [self.email] + [x.email for x in other]
 

	
 
    @property
 
    def api_keys(self):
 
        other = UserApiKeys.query().filter(UserApiKeys.user==self).all()
 
        return [self.api_key] + [x.api_key for x in other]
 

	
 
    @property
 
    def ip_addresses(self):
 
        ret = UserIpMap.query().filter(UserIpMap.user == self).all()
 
        return [x.ip_addr for x in ret]
 

	
 
    @property
 
    def username_and_name(self):
 
        return '%s (%s %s)' % (self.username, self.firstname, self.lastname)
 

	
 
    @property
 
    def full_name(self):
 
        return '%s %s' % (self.firstname, self.lastname)
 

	
 
    @property
 
    def full_name_or_username(self):
 
        return ('%s %s' % (self.firstname, self.lastname)
 
                if (self.firstname and self.lastname) else self.username)
 

	
 
    @property
 
    def full_contact(self):
 
        return '%s %s <%s>' % (self.firstname, self.lastname, self.email)
 

	
 
    @property
 
    def short_contact(self):
 
        return '%s %s' % (self.firstname, self.lastname)
 

	
 
    @property
 
    def is_admin(self):
 
        return self.admin
 

	
 
    @property
 
    def AuthUser(self):
 
        """
 
        Returns instance of AuthUser for this user
 
        """
 
        from kallithea.lib.auth import AuthUser
 
        return AuthUser(user_id=self.user_id, api_key=self.api_key,
 
                        username=self.username)
 

	
 
    @hybrid_property
 
    def user_data(self):
 
        if not self._user_data:
 
            return {}
 

	
 
        try:
 
            return json.loads(self._user_data)
 
        except TypeError:
 
            return {}
 

	
 
    @user_data.setter
 
    def user_data(self, val):
 
        try:
 
            self._user_data = json.dumps(val)
 
        except Exception:
 
            log.error(traceback.format_exc())
 

	
 
    def __unicode__(self):
 
        return u"<%s('id:%s:%s')>" % (self.__class__.__name__,
 
                                      self.user_id, self.username)
 

	
 
    @classmethod
 
    def get_by_username(cls, username, case_insensitive=False, cache=False):
 
        if case_insensitive:
 
            q = cls.query().filter(cls.username.ilike(username))
 
        else:
 
            q = cls.query().filter(cls.username == username)
 

	
 
        if cache:
 
            q = q.options(FromCache(
 
                            "sql_cache_short",
 
                            "get_user_%s" % _hash_key(username)
 
                          )
 
            )
 
        return q.scalar()
 

	
 
    @classmethod
 
    def get_by_api_key(cls, api_key, cache=False, fallback=True):
 
        q = cls.query().filter(cls.api_key == api_key)
 

	
 
        if cache:
 
            q = q.options(FromCache("sql_cache_short",
 
                                    "get_api_key_%s" % api_key))
 
        res = q.scalar()
 

	
 
        if fallback and not res:
 
            #fallback to additional keys
 
            _res = UserApiKeys.query()\
 
                .filter(UserApiKeys.api_key == api_key)\
 
                .filter(or_(UserApiKeys.expires == -1,
 
                            UserApiKeys.expires >= time.time()))\
 
                .first()
 
            if _res:
 
                res = _res.user
 
        return res
 

	
 
    @classmethod
 
    def get_by_email(cls, email, case_insensitive=False, cache=False):
 
        if case_insensitive:
 
            q = cls.query().filter(cls.email.ilike(email))
 
        else:
 
            q = cls.query().filter(cls.email == email)
 

	
 
        if cache:
 
            q = q.options(FromCache("sql_cache_short",
 
                                    "get_email_key_%s" % email))
 

	
 
        ret = q.scalar()
 
        if ret is None:
 
            q = UserEmailMap.query()
 
            # try fetching in alternate email map
 
            if case_insensitive:
 
                q = q.filter(UserEmailMap.email.ilike(email))
 
            else:
 
                q = q.filter(UserEmailMap.email == email)
 
            q = q.options(joinedload(UserEmailMap.user))
 
            if cache:
 
                q = q.options(FromCache("sql_cache_short",
 
                                        "get_email_map_key_%s" % email))
 
            ret = getattr(q.scalar(), 'user', None)
 

	
 
        return ret
 

	
 
    @classmethod
 
    def get_from_cs_author(cls, author):
 
        """
 
        Tries to get User objects out of commit author string
 

	
 
        :param author:
 
        """
 
        from kallithea.lib.helpers import email, author_name
 
        # Valid email in the attribute passed, see if they're in the system
 
        _email = email(author)
 
        if _email:
 
            user = cls.get_by_email(_email, case_insensitive=True)
 
            if user:
 
                return user
 
        # Maybe we can match by username?
 
        _author = author_name(author)
 
        user = cls.get_by_username(_author, case_insensitive=True)
 
        if user:
 
            return user
 

	
 
    def update_lastlogin(self):
 
        """Update user lastlogin"""
 
        self.last_login = datetime.datetime.now()
 
        Session().add(self)
 
        log.debug('updated user %s lastlogin' % self.username)
 

	
 
    @classmethod
 
    def get_first_admin(cls):
 
        user = User.query().filter(User.admin == True).first()
 
        if user is None:
 
            raise Exception('Missing administrative account!')
 
        return user
 

	
 
    @classmethod
 
    def get_default_user(cls, cache=False):
 
        user = User.get_by_username(User.DEFAULT_USER, cache=cache)
 
        if user is None:
 
            raise Exception('Missing default account!')
 
        return user
 

	
 
    def get_api_data(self):
 
        """
 
        Common function for generating user related data for API
 
        """
 
        user = self
 
        data = dict(
 
            user_id=user.user_id,
 
            username=user.username,
 
            firstname=user.name,
 
            lastname=user.lastname,
 
            email=user.email,
 
            emails=user.emails,
 
            api_key=user.api_key,
 
            api_keys=user.api_keys,
 
            active=user.active,
 
            admin=user.admin,
 
            extern_type=user.extern_type,
 
            extern_name=user.extern_name,
 
            last_login=user.last_login,
 
            ip_addresses=user.ip_addresses
 
        )
 
        return data
 

	
 
    def __json__(self):
 
        data = dict(
 
            full_name=self.full_name,
 
            full_name_or_username=self.full_name_or_username,
 
            short_contact=self.short_contact,
 
            full_contact=self.full_contact
 
        )
 
        data.update(self.get_api_data())
 
        return data
 

	
 

	
 
class UserApiKeys(Base, BaseModel):
 
    __tablename__ = 'user_api_keys'
 
    __table_args__ = (
 
        Index('uak_api_key_idx', 'api_key'),
 
        Index('uak_api_key_expires_idx', 'api_key', 'expires'),
 
        UniqueConstraint('api_key'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    __mapper_args__ = {}
 

	
 
    user_api_key_id = Column("user_api_key_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=True, unique=None, default=None)
 
    api_key = Column("api_key", String(255, convert_unicode=False), nullable=False, unique=True)
 
    description = Column('description', UnicodeText(1024))
 
    expires = Column('expires', Float(53), nullable=False)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 

	
 
    user = relationship('User', lazy='joined')
 

	
 
    @property
 
    def expired(self):
 
        if self.expires == -1:
 
            return False
 
        return time.time() > self.expires
 

	
 

	
 
class UserEmailMap(Base, BaseModel):
 
    __tablename__ = 'user_email_map'
 
    __table_args__ = (
 
        Index('uem_email_idx', 'email'),
 
        UniqueConstraint('email'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    __mapper_args__ = {}
 

	
 
    email_id = Column("email_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=True, unique=None, default=None)
 
    _email = Column("email", String(255, convert_unicode=False), nullable=True, unique=False, default=None)
 
    user = relationship('User', lazy='joined')
 

	
 
    @validates('_email')
 
    def validate_email(self, key, email):
 
        # check if this email is not main one
 
        main_email = Session().query(User).filter(User.email == email).scalar()
 
        if main_email is not None:
 
            raise AttributeError('email %s is present is user table' % email)
 
        return email
 

	
 
    @hybrid_property
 
    def email(self):
 
        return self._email
 

	
 
    @email.setter
 
    def email(self, val):
 
        self._email = val.lower() if val else None
 

	
 

	
 
class UserIpMap(Base, BaseModel):
 
    __tablename__ = 'user_ip_map'
 
    __table_args__ = (
 
        UniqueConstraint('user_id', 'ip_addr'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    __mapper_args__ = {}
 

	
 
    ip_id = Column("ip_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=True, unique=None, default=None)
 
    ip_addr = Column("ip_addr", String(255, convert_unicode=False), nullable=True, unique=False, default=None)
 
    active = Column("active", Boolean(), nullable=True, unique=None, default=True)
 
    user = relationship('User', lazy='joined')
 

	
 
    @classmethod
 
    def _get_ip_range(cls, ip_addr):
 
        from kallithea.lib import ipaddr
 
        net = ipaddr.IPNetwork(address=ip_addr)
 
        return [str(net.network), str(net.broadcast)]
 

	
 
    def __json__(self):
 
        return dict(
 
          ip_addr=self.ip_addr,
 
          ip_range=self._get_ip_range(self.ip_addr)
 
        )
 

	
 
    def __unicode__(self):
 
        return u"<%s('user_id:%s=>%s')>" % (self.__class__.__name__,
 
                                            self.user_id, self.ip_addr)
 

	
 
class UserLog(Base, BaseModel):
 
    __tablename__ = 'user_logs'
 
    __table_args__ = (
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    user_log_id = Column("user_log_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=True, unique=None, default=None)
 
    username = Column("username", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=True)
 
    repository_name = Column("repository_name", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    user_ip = Column("user_ip", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    action = Column("action", UnicodeText(1200000, convert_unicode=False), nullable=True, unique=None, default=None)
 
    action_date = Column("action_date", DateTime(timezone=False), nullable=True, unique=None, default=None)
 

	
 
    def __unicode__(self):
 
        return u"<%s('id:%s:%s')>" % (self.__class__.__name__,
 
                                      self.repository_name,
 
                                      self.action)
 

	
 
    @property
 
    def action_as_day(self):
 
        return datetime.date(*self.action_date.timetuple()[:3])
 

	
 
    user = relationship('User')
 
    repository = relationship('Repository', cascade='')
 

	
 

	
 
class UserGroup(Base, BaseModel):
 
    __tablename__ = 'users_groups'
 
    __table_args__ = (
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 

	
 
    users_group_id = Column("users_group_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    users_group_name = Column("users_group_name", String(255, convert_unicode=False), nullable=False, unique=True, default=None)
 
    user_group_description = Column("user_group_description", String(10000, convert_unicode=False), nullable=True, unique=None, default=None)
 
    users_group_active = Column("users_group_active", Boolean(), nullable=True, unique=None, default=None)
 
    inherit_default_permissions = Column("users_group_inherit_default_permissions", Boolean(), nullable=False, unique=None, default=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=False, default=None)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 
    _group_data = Column("group_data", LargeBinary(), nullable=True)  # JSON data
 

	
 
    members = relationship('UserGroupMember', cascade="all, delete-orphan", lazy="joined")
 
    users_group_to_perm = relationship('UserGroupToPerm', cascade='all')
 
    users_group_repo_to_perm = relationship('UserGroupRepoToPerm', cascade='all')
 
    users_group_repo_group_to_perm = relationship('UserGroupRepoGroupToPerm', cascade='all')
 
    user_user_group_to_perm = relationship('UserUserGroupToPerm ', cascade='all')
 
    user_group_user_group_to_perm = relationship('UserGroupUserGroupToPerm ', primaryjoin="UserGroupUserGroupToPerm.target_user_group_id==UserGroup.users_group_id", cascade='all')
 

	
 
    user = relationship('User')
 

	
 
    @hybrid_property
 
    def group_data(self):
 
        if not self._group_data:
 
            return {}
 

	
 
        try:
 
            return json.loads(self._group_data)
 
        except TypeError:
 
            return {}
 

	
 
    @group_data.setter
 
    def group_data(self, val):
 
        try:
 
            self._group_data = json.dumps(val)
 
        except Exception:
 
            log.error(traceback.format_exc())
 

	
 
    def __unicode__(self):
 
        return u"<%s('id:%s:%s')>" % (self.__class__.__name__,
 
                                      self.users_group_id,
 
                                      self.users_group_name)
 

	
 
    @classmethod
 
    def get_by_group_name(cls, group_name, cache=False,
 
                          case_insensitive=False):
 
        if case_insensitive:
 
            q = cls.query().filter(cls.users_group_name.ilike(group_name))
 
        else:
 
            q = cls.query().filter(cls.users_group_name == group_name)
 
        if cache:
 
            q = q.options(FromCache(
 
                            "sql_cache_short",
 
                            "get_user_%s" % _hash_key(group_name)
 
                          )
 
            )
 
        return q.scalar()
 

	
 
    @classmethod
 
    def get(cls, user_group_id, cache=False):
 
        user_group = cls.query()
 
        if cache:
 
            user_group = user_group.options(FromCache("sql_cache_short",
 
                                    "get_users_group_%s" % user_group_id))
 
        return user_group.get(user_group_id)
 

	
 
    def get_api_data(self, with_members=True):
 
        user_group = self
 

	
 
        data = dict(
 
            users_group_id=user_group.users_group_id,
 
            group_name=user_group.users_group_name,
 
            group_description=user_group.user_group_description,
 
            active=user_group.users_group_active,
 
            owner=user_group.user.username,
 
        )
 
        if with_members:
 
            members = []
 
            for user in user_group.members:
 
                user = user.user
 
                members.append(user.get_api_data())
 
            data['members'] = members
 

	
 
        return data
 

	
 

	
 
class UserGroupMember(Base, BaseModel):
 
    __tablename__ = 'users_groups_members'
 
    __table_args__ = (
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 

	
 
    users_group_member_id = Column("users_group_member_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
 

	
 
    user = relationship('User', lazy='joined')
 
    users_group = relationship('UserGroup')
 

	
 
    def __init__(self, gr_id='', u_id=''):
 
        self.users_group_id = gr_id
 
        self.user_id = u_id
 

	
 

	
 
class RepositoryField(Base, BaseModel):
 
    __tablename__ = 'repositories_fields'
 
    __table_args__ = (
 
        UniqueConstraint('repository_id', 'field_key'),  # no-multi field
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    PREFIX = 'ex_'  # prefix used in form to not conflict with already existing fields
 

	
 
    repo_field_id = Column("repo_field_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=None, default=None)
 
    field_key = Column("field_key", String(250, convert_unicode=False))
 
    field_label = Column("field_label", String(1024, convert_unicode=False), nullable=False)
 
    field_value = Column("field_value", String(10000, convert_unicode=False), nullable=False)
 
    field_desc = Column("field_desc", String(1024, convert_unicode=False), nullable=False)
 
    field_type = Column("field_type", String(256), nullable=False, unique=None)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 

	
 
    repository = relationship('Repository')
 

	
 
    @property
 
    def field_key_prefixed(self):
 
        return 'ex_%s' % self.field_key
 

	
 
    @classmethod
 
    def un_prefix_key(cls, key):
 
        if key.startswith(cls.PREFIX):
 
            return key[len(cls.PREFIX):]
 
        return key
 

	
 
    @classmethod
 
    def get_by_key_name(cls, key, repo):
 
        row = cls.query()\
 
                .filter(cls.repository == repo)\
 
                .filter(cls.field_key == key).scalar()
 
        return row
 

	
 

	
 
class Repository(Base, BaseModel):
 
    __tablename__ = 'repositories'
 
    __table_args__ = (
 
        UniqueConstraint('repo_name'),
 
        Index('r_repo_name_idx', 'repo_name'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    DEFAULT_CLONE_URI = '{scheme}://{user}@{netloc}/{repo}'
 
    DEFAULT_CLONE_URI_ID = '{scheme}://{user}@{netloc}/_{repoid}'
 

	
 
    STATE_CREATED = 'repo_state_created'
 
    STATE_PENDING = 'repo_state_pending'
 
    STATE_ERROR = 'repo_state_error'
 

	
 
    repo_id = Column("repo_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    repo_name = Column("repo_name", String(255, convert_unicode=False), nullable=False, unique=True, default=None)
 
    repo_state = Column("repo_state", String(255), nullable=True)
 

	
 
    clone_uri = Column("clone_uri", String(255, convert_unicode=False), nullable=True, unique=False, default=None)
 
    repo_type = Column("repo_type", String(255, convert_unicode=False), nullable=False, unique=False, default=None)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=False, default=None)
 
    private = Column("private", Boolean(), nullable=True, unique=None, default=None)
 
    enable_statistics = Column("statistics", Boolean(), nullable=True, unique=None, default=True)
 
    enable_downloads = Column("downloads", Boolean(), nullable=True, unique=None, default=True)
 
    description = Column("description", String(10000, convert_unicode=False), nullable=True, unique=None, default=None)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=True, unique=None, default=datetime.datetime.now)
 
    updated_on = Column('updated_on', DateTime(timezone=False), nullable=True, unique=None, default=datetime.datetime.now)
 
    _landing_revision = Column("landing_revision", String(255, convert_unicode=False), nullable=False, unique=False, default=None)
 
    enable_locking = Column("enable_locking", Boolean(), nullable=False, unique=None, default=False)
 
    _locked = Column("locked", String(255, convert_unicode=False), nullable=True, unique=False, default=None)
 
    _changeset_cache = Column("changeset_cache", LargeBinary(), nullable=True) #JSON data
 

	
 
    fork_id = Column("fork_id", Integer(), ForeignKey('repositories.repo_id'), nullable=True, unique=False, default=None)
 
    group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'), nullable=True, unique=False, default=None)
 

	
 
    user = relationship('User')
 
    fork = relationship('Repository', remote_side=repo_id)
 
    group = relationship('RepoGroup')
 
    repo_to_perm = relationship('UserRepoToPerm', cascade='all', order_by='UserRepoToPerm.repo_to_perm_id')
 
    users_group_to_perm = relationship('UserGroupRepoToPerm', cascade='all')
 
    stats = relationship('Statistics', cascade='all', uselist=False)
 

	
 
    followers = relationship('UserFollowing',
 
                             primaryjoin='UserFollowing.follows_repo_id==Repository.repo_id',
 
                             cascade='all')
 
    extra_fields = relationship('RepositoryField',
 
                                cascade="all, delete-orphan")
 

	
 
    logs = relationship('UserLog')
 
    comments = relationship('ChangesetComment', cascade="all, delete-orphan")
 

	
 
    pull_requests_org = relationship('PullRequest',
 
                    primaryjoin='PullRequest.org_repo_id==Repository.repo_id',
 
                    cascade="all, delete-orphan")
 

	
 
    pull_requests_other = relationship('PullRequest',
 
                    primaryjoin='PullRequest.other_repo_id==Repository.repo_id',
 
                    cascade="all, delete-orphan")
 

	
 
    def __unicode__(self):
 
        return u"<%s('%s:%s')>" % (self.__class__.__name__, self.repo_id,
 
                                   safe_unicode(self.repo_name))
 

	
 
    @hybrid_property
 
    def landing_rev(self):
 
        # always should return [rev_type, rev]
 
        if self._landing_revision:
 
            _rev_info = self._landing_revision.split(':')
 
            if len(_rev_info) < 2:
 
                _rev_info.insert(0, 'rev')
 
            return [_rev_info[0], _rev_info[1]]
 
        return [None, None]
 

	
 
    @landing_rev.setter
 
    def landing_rev(self, val):
 
        if ':' not in val:
 
            raise ValueError('value must be delimited with `:` and consist '
 
                             'of <rev_type>:<rev>, got %s instead' % val)
 
        self._landing_revision = val
 

	
 
    @hybrid_property
 
    def locked(self):
 
        # always should return [user_id, timelocked]
 
        if self._locked:
 
            _lock_info = self._locked.split(':')
 
            return int(_lock_info[0]), _lock_info[1]
 
        return [None, None]
 

	
 
    @locked.setter
 
    def locked(self, val):
 
        if val and isinstance(val, (list, tuple)):
 
            self._locked = ':'.join(map(str, val))
 
        else:
 
            self._locked = None
 

	
 
    @hybrid_property
 
    def changeset_cache(self):
 
        from kallithea.lib.vcs.backends.base import EmptyChangeset
 
        dummy = EmptyChangeset().__json__()
 
        if not self._changeset_cache:
 
            return dummy
 
        try:
 
            return json.loads(self._changeset_cache)
 
        except TypeError:
 
            return dummy
 

	
 
    @changeset_cache.setter
 
    def changeset_cache(self, val):
 
        try:
 
            self._changeset_cache = json.dumps(val)
 
        except Exception:
 
            log.error(traceback.format_exc())
 

	
 
    @classmethod
 
    def url_sep(cls):
 
        return URL_SEP
 

	
 
    @classmethod
 
    def normalize_repo_name(cls, repo_name):
 
        """
 
        Normalizes os specific repo_name to the format internally stored inside
 
        dabatabase using URL_SEP
 

	
 
        :param cls:
 
        :param repo_name:
 
        """
 
        return cls.url_sep().join(repo_name.split(os.sep))
 

	
 
    @classmethod
 
    def get_by_repo_name(cls, repo_name):
 
        q = Session().query(cls).filter(cls.repo_name == repo_name)
 
        q = q.options(joinedload(Repository.fork))\
 
                .options(joinedload(Repository.user))\
 
                .options(joinedload(Repository.group))
 
        return q.scalar()
 

	
 
    @classmethod
 
    def get_by_full_path(cls, repo_full_path):
 
        repo_name = repo_full_path.split(cls.base_path(), 1)[-1]
 
        repo_name = cls.normalize_repo_name(repo_name)
 
        return cls.get_by_repo_name(repo_name.strip(URL_SEP))
 

	
 
    @classmethod
 
    def get_repo_forks(cls, repo_id):
 
        return cls.query().filter(Repository.fork_id == repo_id)
 

	
 
    @classmethod
 
    def base_path(cls):
 
        """
 
        Returns base path where all repos are stored
 

	
 
        :param cls:
 
        """
 
        q = Session().query(Ui)\
 
            .filter(Ui.ui_key == cls.url_sep())
 
        q = q.options(FromCache("sql_cache_short", "repository_repo_path"))
 
        return q.one().ui_value
 

	
 
    @property
 
    def forks(self):
 
        """
 
        Return forks of this repo
 
        """
 
        return Repository.get_repo_forks(self.repo_id)
 

	
 
    @property
 
    def parent(self):
 
        """
 
        Returns fork parent
 
        """
 
        return self.fork
 

	
 
    @property
 
    def just_name(self):
 
        return self.repo_name.split(Repository.url_sep())[-1]
 

	
 
    @property
 
    def groups_with_parents(self):
 
        groups = []
 
        if self.group is None:
 
            return groups
 

	
 
        cur_gr = self.group
 
        groups.insert(0, cur_gr)
 
        while 1:
 
            gr = getattr(cur_gr, 'parent_group', None)
 
            cur_gr = cur_gr.parent_group
 
            if gr is None:
 
                break
 
            groups.insert(0, gr)
 

	
 
        return groups
 

	
 
    @property
 
    def groups_and_repo(self):
 
        return self.groups_with_parents, self.just_name, self.repo_name
 

	
 
    @LazyProperty
 
    def repo_path(self):
 
        """
 
        Returns base full path for that repository means where it actually
 
        exists on a filesystem
 
        """
 
        q = Session().query(Ui).filter(Ui.ui_key ==
 
                                              Repository.url_sep())
 
        q = q.options(FromCache("sql_cache_short", "repository_repo_path"))
 
        return q.one().ui_value
 

	
 
    @property
 
    def repo_full_path(self):
 
        p = [self.repo_path]
 
        # we need to split the name by / since this is how we store the
 
        # names in the database, but that eventually needs to be converted
 
        # into a valid system path
 
        p += self.repo_name.split(Repository.url_sep())
 
        return os.path.join(*map(safe_unicode, p))
 

	
 
    @property
 
    def cache_keys(self):
 
        """
 
        Returns associated cache keys for that repo
 
        """
 
        return CacheInvalidation.query()\
 
            .filter(CacheInvalidation.cache_args == self.repo_name)\
 
            .order_by(CacheInvalidation.cache_key)\
 
            .all()
 

	
 
    def get_new_name(self, repo_name):
 
        """
 
        returns new full repository name based on assigned group and new new
 

	
 
        :param group_name:
 
        """
 
        path_prefix = self.group.full_path_splitted if self.group else []
 
        return Repository.url_sep().join(path_prefix + [repo_name])
 

	
 
    @property
 
    def _ui(self):
 
        """
 
        Creates an db based ui object for this repository
 
        """
 
        from kallithea.lib.utils import make_ui
 
        return make_ui('db', clear_session=False)
 

	
 
    @classmethod
 
    def is_valid(cls, repo_name):
 
        """
 
        returns True if given repo name is a valid filesystem repository
 

	
 
        :param cls:
 
        :param repo_name:
 
        """
 
        from kallithea.lib.utils import is_valid_repo
 

	
 
        return is_valid_repo(repo_name, cls.base_path())
 

	
 
    def get_api_data(self):
 
        """
 
        Common function for generating repo api data
 

	
 
        """
 
        repo = self
 
        data = dict(
 
            repo_id=repo.repo_id,
 
            repo_name=repo.repo_name,
 
            repo_type=repo.repo_type,
 
            clone_uri=repo.clone_uri,
 
            private=repo.private,
 
            created_on=repo.created_on,
 
            description=repo.description,
 
            landing_rev=repo.landing_rev,
 
            owner=repo.user.username,
 
            fork_of=repo.fork.repo_name if repo.fork else None,
 
            enable_statistics=repo.enable_statistics,
 
            enable_locking=repo.enable_locking,
 
            enable_downloads=repo.enable_downloads,
 
            last_changeset=repo.changeset_cache,
 
            locked_by=User.get(self.locked[0]).get_api_data() \
 
                if self.locked[0] else None,
 
            locked_date=time_to_datetime(self.locked[1]) \
 
                if self.locked[1] else None
 
        )
 
        rc_config = Setting.get_app_settings()
 
        repository_fields = str2bool(rc_config.get('repository_fields'))
 
        if repository_fields:
 
            for f in self.extra_fields:
 
                data[f.field_key_prefixed] = f.field_value
 

	
 
        return data
 

	
 
    @classmethod
 
    def lock(cls, repo, user_id, lock_time=None):
 
        if not lock_time:
 
            lock_time = time.time()
 
        repo.locked = [user_id, lock_time]
 
        Session().add(repo)
 
        Session().commit()
 

	
 
    @classmethod
 
    def unlock(cls, repo):
 
        repo.locked = None
 
        Session().add(repo)
 
        Session().commit()
 

	
 
    @classmethod
 
    def getlock(cls, repo):
 
        return repo.locked
 

	
 
    @property
 
    def last_db_change(self):
 
        return self.updated_on
 

	
 
    @property
 
    def clone_uri_hidden(self):
 
        clone_uri = self.clone_uri
 
        if clone_uri:
 
            import urlobject
 
            url_obj = urlobject.URLObject(self.clone_uri)
 
            if url_obj.password:
 
                clone_uri = url_obj.with_password('*****')
 
        return clone_uri
 

	
 
    def clone_url(self, **override):
 
        import kallithea.lib.helpers as h
 
        qualified_home_url = h.canonical_url('home')
 

	
 
        uri_tmpl = None
 
        if 'with_id' in override:
 
            uri_tmpl = self.DEFAULT_CLONE_URI_ID
 
            del override['with_id']
 

	
 
        if 'uri_tmpl' in override:
 
            uri_tmpl = override['uri_tmpl']
 
            del override['uri_tmpl']
 

	
 
        # we didn't override our tmpl from **overrides
 
        if not uri_tmpl:
 
            uri_tmpl = self.DEFAULT_CLONE_URI
 
            try:
 
                from pylons import tmpl_context as c
 
                uri_tmpl = c.clone_uri_tmpl
 
            except Exception:
 
            except AttributeError:
 
                # in any case if we call this outside of request context,
 
                # ie, not having tmpl_context set up
 
                pass
 

	
 
        return get_clone_url(uri_tmpl=uri_tmpl,
 
                             qualifed_home_url=qualified_home_url,
 
                             repo_name=self.repo_name,
 
                             repo_id=self.repo_id, **override)
 

	
 
    def set_state(self, state):
 
        self.repo_state = state
 
        Session().add(self)
 
    #==========================================================================
 
    # SCM PROPERTIES
 
    #==========================================================================
 

	
 
    def get_changeset(self, rev=None):
 
        return get_changeset_safe(self.scm_instance, rev)
 

	
 
    def get_landing_changeset(self):
 
        """
 
        Returns landing changeset, or if that doesn't exist returns the tip
 
        """
 
        _rev_type, _rev = self.landing_rev
 
        cs = self.get_changeset(_rev)
 
        if isinstance(cs, EmptyChangeset):
 
            return self.get_changeset()
 
        return cs
 

	
 
    def update_changeset_cache(self, cs_cache=None):
 
        """
 
        Update cache of last changeset for repository, keys should be::
 

	
 
            short_id
 
            raw_id
 
            revision
 
            message
 
            date
 
            author
 

	
 
        :param cs_cache:
 
        """
 
        from kallithea.lib.vcs.backends.base import BaseChangeset
 
        if cs_cache is None:
 
            cs_cache = EmptyChangeset()
 
            # use no-cache version here
 
            scm_repo = self.scm_instance_no_cache()
 
            if scm_repo:
 
                cs_cache = scm_repo.get_changeset()
 

	
 
        if isinstance(cs_cache, BaseChangeset):
 
            cs_cache = cs_cache.__json__()
 

	
 
        if (not self.changeset_cache or cs_cache['raw_id'] != self.changeset_cache['raw_id']):
 
            _default = datetime.datetime.fromtimestamp(0)
 
            last_change = cs_cache.get('date') or _default
 
            log.debug('updated repo %s with new cs cache %s'
 
                      % (self.repo_name, cs_cache))
 
            self.updated_on = last_change
 
            self.changeset_cache = cs_cache
 
            Session().add(self)
 
            Session().commit()
 
        else:
 
            log.debug('changeset_cache for %s already up to date with %s'
 
                      % (self.repo_name, cs_cache['raw_id']))
 

	
 
    @property
 
    def tip(self):
 
        return self.get_changeset('tip')
 

	
 
    @property
 
    def author(self):
 
        return self.tip.author
 

	
 
    @property
 
    def last_change(self):
 
        return self.scm_instance.last_change
 

	
 
    def get_comments(self, revisions=None):
 
        """
 
        Returns comments for this repository grouped by revisions
 

	
 
        :param revisions: filter query by revisions only
 
        """
 
        cmts = ChangesetComment.query()\
 
            .filter(ChangesetComment.repo == self)
 
        if revisions:
 
            cmts = cmts.filter(ChangesetComment.revision.in_(revisions))
 
        grouped = collections.defaultdict(list)
 
        for cmt in cmts.all():
 
            grouped[cmt.revision].append(cmt)
 
        return grouped
 

	
 
    def statuses(self, revisions):
 
        """
 
        Returns statuses for this repository
 

	
 
        :param revisions: list of revisions to get statuses for
 
        """
 
        if not revisions:
 
            return {}
 

	
 
        statuses = ChangesetStatus.query()\
 
            .filter(ChangesetStatus.repo == self)\
 
            .filter(ChangesetStatus.version == 0)\
 
            .filter(ChangesetStatus.revision.in_(revisions))
 
        grouped = {}
 

	
 
        stat = ChangesetStatus.DEFAULT
 
        status_lbl = ChangesetStatus.get_status_lbl(stat)
 
        for pr in PullRequest.query().filter(PullRequest.org_repo == self).all():
 
            for rev in pr.revisions:
 
                pr_id = pr.pull_request_id
 
                pr_repo = pr.other_repo.repo_name
 
                grouped[rev] = [stat, status_lbl, pr_id, pr_repo]
 

	
 
        for stat in statuses.all():
 
            pr_id = pr_repo = None
 
            if stat.pull_request:
 
                pr_id = stat.pull_request.pull_request_id
 
                pr_repo = stat.pull_request.other_repo.repo_name
 
            grouped[stat.revision] = [str(stat.status), stat.status_lbl,
 
                                      pr_id, pr_repo]
 
        return grouped
 

	
 
    def _repo_size(self):
 
        from kallithea.lib import helpers as h
 
        log.debug('calculating repository size...')
 
        return h.format_byte_size(self.scm_instance.size)
 

	
 
    #==========================================================================
 
    # SCM CACHE INSTANCE
 
    #==========================================================================
 

	
 
    def set_invalidate(self):
 
        """
 
        Mark caches of this repo as invalid.
 
        """
 
        CacheInvalidation.set_invalidate(self.repo_name)
 

	
 
    def scm_instance_no_cache(self):
 
        return self.__get_instance()
 

	
 
    @property
 
    def scm_instance(self):
 
        import kallithea
 
        full_cache = str2bool(kallithea.CONFIG.get('vcs_full_cache'))
 
        if full_cache:
 
            return self.scm_instance_cached()
 
        return self.__get_instance()
 

	
 
    def scm_instance_cached(self, valid_cache_keys=None):
 
        @cache_region('long_term')
 
        def _c(repo_name):
 
            return self.__get_instance()
 
        rn = self.repo_name
 

	
 
        valid = CacheInvalidation.test_and_set_valid(rn, None, valid_cache_keys=valid_cache_keys)
 
        if not valid:
 
            log.debug('Cache for %s invalidated, getting new object' % (rn))
 
            region_invalidate(_c, None, rn)
 
        else:
 
            log.debug('Getting scm_instance of %s from cache' % (rn))
 
        return _c(rn)
 

	
 
    def __get_instance(self):
 
        repo_full_path = self.repo_full_path
 

	
 
        alias = get_scm(repo_full_path)[0]
 
        log.debug('Creating instance of %s repository from %s'
 
                  % (alias, repo_full_path))
 
        backend = get_backend(alias)
 

	
 
        if alias == 'hg':
 
            repo = backend(safe_str(repo_full_path), create=False,
 
                           baseui=self._ui)
 
        else:
 
            repo = backend(repo_full_path, create=False)
 

	
 
        return repo
 

	
 
    def __json__(self):
 
        return dict(landing_rev = self.landing_rev)
 

	
 
class RepoGroup(Base, BaseModel):
 
    __tablename__ = 'groups'
 
    __table_args__ = (
 
        UniqueConstraint('group_name', 'group_parent_id'),
 
        CheckConstraint('group_id != group_parent_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    __mapper_args__ = {'order_by': 'group_name'}
 

	
 
    SEP = ' &raquo; '
 

	
 
    group_id = Column("group_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    group_name = Column("group_name", String(255, convert_unicode=False), nullable=False, unique=True, default=None)
 
    group_parent_id = Column("group_parent_id", Integer(), ForeignKey('groups.group_id'), nullable=True, unique=None, default=None)
 
    group_description = Column("group_description", String(10000, convert_unicode=False), nullable=True, unique=None, default=None)
 
    enable_locking = Column("enable_locking", Boolean(), nullable=False, unique=None, default=False)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=False, default=None)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 

	
 
    repo_group_to_perm = relationship('UserRepoGroupToPerm', cascade='all', order_by='UserRepoGroupToPerm.group_to_perm_id')
 
    users_group_to_perm = relationship('UserGroupRepoGroupToPerm', cascade='all')
 
    parent_group = relationship('RepoGroup', remote_side=group_id)
 
    user = relationship('User')
 

	
 
    def __init__(self, group_name='', parent_group=None):
 
        self.group_name = group_name
 
        self.parent_group = parent_group
 

	
 
    def __unicode__(self):
 
        return u"<%s('id:%s:%s')>" % (self.__class__.__name__, self.group_id,
 
                                      self.group_name)
 

	
 
    @classmethod
 
    def _generate_choice(cls, repo_group):
 
        from webhelpers.html import literal as _literal
 
        _name = lambda k: _literal(cls.SEP.join(k))
 
        return repo_group.group_id, _name(repo_group.full_path_splitted)
 

	
 
    @classmethod
 
    def groups_choices(cls, groups=None, show_empty_group=True):
 
        if not groups:
 
            groups = cls.query().all()
 

	
 
        repo_groups = []
 
        if show_empty_group:
 
            repo_groups = [('-1', u'-- %s --' % _('top level'))]
 

	
 
        repo_groups.extend([cls._generate_choice(x) for x in groups])
 

	
 
        repo_groups = sorted(repo_groups, key=lambda t: t[1].split(cls.SEP)[0])
 
        return repo_groups
 

	
 
    @classmethod
 
    def url_sep(cls):
 
        return URL_SEP
 

	
 
    @classmethod
 
    def get_by_group_name(cls, group_name, cache=False, case_insensitive=False):
 
        if case_insensitive:
 
            gr = cls.query()\
 
                .filter(cls.group_name.ilike(group_name))
 
        else:
 
            gr = cls.query()\
 
                .filter(cls.group_name == group_name)
 
        if cache:
 
            gr = gr.options(FromCache(
 
                            "sql_cache_short",
 
                            "get_group_%s" % _hash_key(group_name)
 
                            )
 
            )
 
        return gr.scalar()
 

	
 
    @property
 
    def parents(self):
 
        parents_recursion_limit = 10
 
        groups = []
 
        if self.parent_group is None:
 
            return groups
 
        cur_gr = self.parent_group
 
        groups.insert(0, cur_gr)
 
        cnt = 0
 
        while 1:
 
            cnt += 1
 
            gr = getattr(cur_gr, 'parent_group', None)
 
            cur_gr = cur_gr.parent_group
 
            if gr is None:
 
                break
 
            if cnt == parents_recursion_limit:
 
                # this will prevent accidental infinit loops
 
                log.error(('more than %s parents found for group %s, stopping '
 
                           'recursive parent fetching' % (parents_recursion_limit, self)))
 
                break
 

	
 
            groups.insert(0, gr)
 
        return groups
 

	
 
    @property
 
    def children(self):
 
        return RepoGroup.query().filter(RepoGroup.parent_group == self)
 

	
 
    @property
 
    def name(self):
 
        return self.group_name.split(RepoGroup.url_sep())[-1]
 

	
 
    @property
 
    def full_path(self):
 
        return self.group_name
 

	
 
    @property
 
    def full_path_splitted(self):
 
        return self.group_name.split(RepoGroup.url_sep())
 

	
 
    @property
 
    def repositories(self):
 
        return Repository.query()\
 
                .filter(Repository.group == self)\
 
                .order_by(Repository.repo_name)
 

	
 
    @property
 
    def repositories_recursive_count(self):
 
        cnt = self.repositories.count()
 

	
 
        def children_count(group):
 
            cnt = 0
 
            for child in group.children:
 
                cnt += child.repositories.count()
 
                cnt += children_count(child)
 
            return cnt
 

	
 
        return cnt + children_count(self)
 

	
 
    def _recursive_objects(self, include_repos=True):
 
        all_ = []
 

	
 
        def _get_members(root_gr):
 
            if include_repos:
 
                for r in root_gr.repositories:
 
                    all_.append(r)
 
            childs = root_gr.children.all()
 
            if childs:
 
                for gr in childs:
 
                    all_.append(gr)
 
                    _get_members(gr)
 

	
 
        _get_members(self)
 
        return [self] + all_
 

	
 
    def recursive_groups_and_repos(self):
 
        """
 
        Recursive return all groups, with repositories in those groups
 
        """
 
        return self._recursive_objects()
 

	
 
    def recursive_groups(self):
 
        """
 
        Returns all children groups for this group including children of children
 
        """
 
        return self._recursive_objects(include_repos=False)
 

	
 
    def get_new_name(self, group_name):
 
        """
 
        returns new full group name based on parent and new name
 

	
 
        :param group_name:
 
        """
 
        path_prefix = (self.parent_group.full_path_splitted if
 
                       self.parent_group else [])
 
        return RepoGroup.url_sep().join(path_prefix + [group_name])
 

	
 
    def get_api_data(self):
 
        """
 
        Common function for generating api data
 

	
 
        """
 
        group = self
 
        data = dict(
 
            group_id=group.group_id,
 
            group_name=group.group_name,
 
            group_description=group.group_description,
 
            parent_group=group.parent_group.group_name if group.parent_group else None,
 
            repositories=[x.repo_name for x in group.repositories],
 
            owner=group.user.username
 
        )
 
        return data
 

	
 

	
 
class Permission(Base, BaseModel):
 
    __tablename__ = 'permissions'
 
    __table_args__ = (
 
        Index('p_perm_name_idx', 'permission_name'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    PERMS = [
 
        ('hg.admin', _('Kallithea Administrator')),
 

	
 
        ('repository.none', _('Repository no access')),
 
        ('repository.read', _('Repository read access')),
 
        ('repository.write', _('Repository write access')),
 
        ('repository.admin', _('Repository admin access')),
 

	
 
        ('group.none', _('Repository group no access')),
 
        ('group.read', _('Repository group read access')),
 
        ('group.write', _('Repository group write access')),
 
        ('group.admin', _('Repository group admin access')),
 

	
 
        ('usergroup.none', _('User group no access')),
 
        ('usergroup.read', _('User group read access')),
 
        ('usergroup.write', _('User group write access')),
 
        ('usergroup.admin', _('User group admin access')),
 

	
 
        ('hg.repogroup.create.false', _('Repository Group creation disabled')),
 
        ('hg.repogroup.create.true', _('Repository Group creation enabled')),
 

	
 
        ('hg.usergroup.create.false', _('User Group creation disabled')),
 
        ('hg.usergroup.create.true', _('User Group creation enabled')),
 

	
 
        ('hg.create.none', _('Repository creation disabled')),
 
        ('hg.create.repository', _('Repository creation enabled')),
 
        ('hg.create.write_on_repogroup.true', _('Repository creation enabled with write permission to a repository group')),
 
        ('hg.create.write_on_repogroup.false', _('Repository creation disabled with write permission to a repository group')),
 

	
 
        ('hg.fork.none', _('Repository forking disabled')),
 
        ('hg.fork.repository', _('Repository forking enabled')),
 

	
 
        ('hg.register.none', _('Registration disabled')),
 
        ('hg.register.manual_activate', _('User Registration with manual account activation')),
 
        ('hg.register.auto_activate', _('User Registration with automatic account activation')),
 

	
 
        ('hg.extern_activate.manual', _('Manual activation of external account')),
 
        ('hg.extern_activate.auto', _('Automatic activation of external account')),
 

	
 
    ]
 

	
 
    #definition of system default permissions for DEFAULT user
 
    DEFAULT_USER_PERMISSIONS = [
 
        'repository.read',
 
        'group.read',
 
        'usergroup.read',
 
        'hg.create.repository',
 
        'hg.create.write_on_repogroup.true',
 
        'hg.fork.repository',
 
        'hg.register.manual_activate',
 
        'hg.extern_activate.auto',
 
    ]
 

	
 
    # defines which permissions are more important higher the more important
 
    # Weight defines which permissions are more important.
 
    # The higher number the more important.
 
    PERM_WEIGHTS = {
 
        'repository.none': 0,
 
        'repository.read': 1,
 
        'repository.write': 3,
 
        'repository.admin': 4,
 

	
 
        'group.none': 0,
 
        'group.read': 1,
 
        'group.write': 3,
 
        'group.admin': 4,
 

	
 
        'usergroup.none': 0,
 
        'usergroup.read': 1,
 
        'usergroup.write': 3,
 
        'usergroup.admin': 4,
 
        'hg.repogroup.create.false': 0,
 
        'hg.repogroup.create.true': 1,
 

	
 
        'hg.usergroup.create.false': 0,
 
        'hg.usergroup.create.true': 1,
 

	
 
        'hg.fork.none': 0,
 
        'hg.fork.repository': 1,
 
        'hg.create.none': 0,
 
        'hg.create.repository': 1
 
    }
 

	
 
    permission_id = Column("permission_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    permission_name = Column("permission_name", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    permission_longname = Column("permission_longname", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 

	
 
    def __unicode__(self):
 
        return u"<%s('%s:%s')>" % (
 
            self.__class__.__name__, self.permission_id, self.permission_name
 
        )
 

	
 
    @classmethod
 
    def get_by_key(cls, key):
 
        return cls.query().filter(cls.permission_name == key).scalar()
 

	
 
    @classmethod
 
    def get_default_perms(cls, default_user_id):
 
        q = Session().query(UserRepoToPerm, Repository, cls)\
 
         .join((Repository, UserRepoToPerm.repository_id == Repository.repo_id))\
 
         .join((cls, UserRepoToPerm.permission_id == cls.permission_id))\
 
         .filter(UserRepoToPerm.user_id == default_user_id)
 

	
 
        return q.all()
 

	
 
    @classmethod
 
    def get_default_group_perms(cls, default_user_id):
 
        q = Session().query(UserRepoGroupToPerm, RepoGroup, cls)\
 
         .join((RepoGroup, UserRepoGroupToPerm.group_id == RepoGroup.group_id))\
 
         .join((cls, UserRepoGroupToPerm.permission_id == cls.permission_id))\
 
         .filter(UserRepoGroupToPerm.user_id == default_user_id)
 

	
 
        return q.all()
 

	
 
    @classmethod
 
    def get_default_user_group_perms(cls, default_user_id):
 
        q = Session().query(UserUserGroupToPerm, UserGroup, cls)\
 
         .join((UserGroup, UserUserGroupToPerm.user_group_id == UserGroup.users_group_id))\
 
         .join((cls, UserUserGroupToPerm.permission_id == cls.permission_id))\
 
         .filter(UserUserGroupToPerm.user_id == default_user_id)
 

	
 
        return q.all()
 

	
 

	
 
class UserRepoToPerm(Base, BaseModel):
 
    __tablename__ = 'repo_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('user_id', 'repository_id', 'permission_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    repo_to_perm_id = Column("repo_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 
    repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=None, default=None)
 

	
 
    user = relationship('User')
 
    repository = relationship('Repository')
 
    permission = relationship('Permission')
 

	
 
    @classmethod
 
    def create(cls, user, repository, permission):
 
        n = cls()
 
        n.user = user
 
        n.repository = repository
 
        n.permission = permission
 
        Session().add(n)
 
        return n
 

	
 
    def __unicode__(self):
 
        return u'<%s => %s >' % (self.user, self.repository)
 

	
 

	
 
class UserUserGroupToPerm(Base, BaseModel):
 
    __tablename__ = 'user_user_group_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('user_id', 'user_group_id', 'permission_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    user_user_group_to_perm_id = Column("user_user_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 
    user_group_id = Column("user_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
 

	
 
    user = relationship('User')
 
    user_group = relationship('UserGroup')
 
    permission = relationship('Permission')
 

	
 
    @classmethod
 
    def create(cls, user, user_group, permission):
 
        n = cls()
 
        n.user = user
 
        n.user_group = user_group
 
        n.permission = permission
 
        Session().add(n)
 
        return n
 

	
 
    def __unicode__(self):
 
        return u'<%s => %s >' % (self.user, self.user_group)
 

	
 

	
 
class UserToPerm(Base, BaseModel):
 
    __tablename__ = 'user_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('user_id', 'permission_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    user_to_perm_id = Column("user_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 

	
 
    user = relationship('User')
 
    permission = relationship('Permission', lazy='joined')
 

	
 
    def __unicode__(self):
 
        return u'<%s => %s >' % (self.user, self.permission)
 

	
 

	
 
class UserGroupRepoToPerm(Base, BaseModel):
 
    __tablename__ = 'users_group_repo_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('repository_id', 'users_group_id', 'permission_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    users_group_to_perm_id = Column("users_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 
    repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=None, default=None)
 

	
 
    users_group = relationship('UserGroup')
 
    permission = relationship('Permission')
 
    repository = relationship('Repository')
 

	
 
    @classmethod
 
    def create(cls, users_group, repository, permission):
 
        n = cls()
 
        n.users_group = users_group
 
        n.repository = repository
 
        n.permission = permission
 
        Session().add(n)
 
        return n
 

	
 
    def __unicode__(self):
 
        return u'<UserGroupRepoToPerm:%s => %s >' % (self.users_group, self.repository)
 

	
 

	
 
class UserGroupUserGroupToPerm(Base, BaseModel):
 
    __tablename__ = 'user_group_user_group_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('target_user_group_id', 'user_group_id', 'permission_id'),
 
        CheckConstraint('target_user_group_id != user_group_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    user_group_user_group_to_perm_id = Column("user_group_user_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    target_user_group_id = Column("target_user_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 
    user_group_id = Column("user_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
 

	
 
    target_user_group = relationship('UserGroup', primaryjoin='UserGroupUserGroupToPerm.target_user_group_id==UserGroup.users_group_id')
 
    user_group = relationship('UserGroup', primaryjoin='UserGroupUserGroupToPerm.user_group_id==UserGroup.users_group_id')
 
    permission = relationship('Permission')
 

	
 
    @classmethod
 
    def create(cls, target_user_group, user_group, permission):
 
        n = cls()
 
        n.target_user_group = target_user_group
 
        n.user_group = user_group
 
        n.permission = permission
 
        Session().add(n)
 
        return n
 

	
 
    def __unicode__(self):
 
        return u'<UserGroupUserGroup:%s => %s >' % (self.target_user_group, self.user_group)
 

	
 

	
 
class UserGroupToPerm(Base, BaseModel):
 
    __tablename__ = 'users_group_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('users_group_id', 'permission_id',),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    users_group_to_perm_id = Column("users_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 

	
 
    users_group = relationship('UserGroup')
 
    permission = relationship('Permission')
 

	
 

	
 
class UserRepoGroupToPerm(Base, BaseModel):
 
    __tablename__ = 'user_repo_group_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('user_id', 'group_id', 'permission_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 

	
 
    group_to_perm_id = Column("group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
 
    group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 

	
 
    user = relationship('User')
 
    group = relationship('RepoGroup')
 
    permission = relationship('Permission')
 

	
 
    @classmethod
 
    def create(cls, user, repository_group, permission):
 
        n = cls()
 
        n.user = user
 
        n.group = repository_group
 
        n.permission = permission
 
        Session().add(n)
 
        return n
 

	
 

	
 
class UserGroupRepoGroupToPerm(Base, BaseModel):
 
    __tablename__ = 'users_group_repo_group_to_perm'
 
    __table_args__ = (
 
        UniqueConstraint('users_group_id', 'group_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 

	
 
    users_group_repo_group_to_perm_id = Column("users_group_repo_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
 
    group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'), nullable=False, unique=None, default=None)
 
    permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
 

	
 
    users_group = relationship('UserGroup')
 
    permission = relationship('Permission')
 
    group = relationship('RepoGroup')
 

	
 
    @classmethod
 
    def create(cls, user_group, repository_group, permission):
 
        n = cls()
 
        n.users_group = user_group
 
        n.group = repository_group
 
        n.permission = permission
 
        Session().add(n)
 
        return n
 

	
 

	
 
class Statistics(Base, BaseModel):
 
    __tablename__ = 'statistics'
 
    __table_args__ = (
 
         UniqueConstraint('repository_id'),
 
         {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
          'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    stat_id = Column("stat_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=True, default=None)
 
    stat_on_revision = Column("stat_on_revision", Integer(), nullable=False)
 
    commit_activity = Column("commit_activity", LargeBinary(1000000), nullable=False)#JSON data
 
    commit_activity_combined = Column("commit_activity_combined", LargeBinary(), nullable=False)#JSON data
 
    languages = Column("languages", LargeBinary(1000000), nullable=False)#JSON data
 

	
 
    repository = relationship('Repository', single_parent=True)
 

	
 

	
 
class UserFollowing(Base, BaseModel):
 
    __tablename__ = 'user_followings'
 
    __table_args__ = (
 
        UniqueConstraint('user_id', 'follows_repository_id'),
 
        UniqueConstraint('user_id', 'follows_user_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 

	
 
    user_following_id = Column("user_following_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
 
    follows_repo_id = Column("follows_repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=True, unique=None, default=None)
 
    follows_user_id = Column("follows_user_id", Integer(), ForeignKey('users.user_id'), nullable=True, unique=None, default=None)
 
    follows_from = Column('follows_from', DateTime(timezone=False), nullable=True, unique=None, default=datetime.datetime.now)
 

	
 
    user = relationship('User', primaryjoin='User.user_id==UserFollowing.user_id')
 

	
 
    follows_user = relationship('User', primaryjoin='User.user_id==UserFollowing.follows_user_id')
 
    follows_repository = relationship('Repository', order_by='Repository.repo_name')
 

	
 
    @classmethod
 
    def get_repo_followers(cls, repo_id):
 
        return cls.query().filter(cls.follows_repo_id == repo_id)
 

	
 

	
 
class CacheInvalidation(Base, BaseModel):
 
    __tablename__ = 'cache_invalidation'
 
    __table_args__ = (
 
        UniqueConstraint('cache_key'),
 
        Index('key_idx', 'cache_key'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    # cache_id, not used
 
    cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
 
    # cache_key as created by _get_cache_key
 
    cache_key = Column("cache_key", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    # cache_args is a repo_name
 
    cache_args = Column("cache_args", String(255, convert_unicode=False), nullable=True, unique=None, default=None)
 
    # instance sets cache_active True when it is caching,
 
    # other instances set cache_active to False to indicate that this cache is invalid
 
    cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)
 

	
 
    def __init__(self, cache_key, repo_name=''):
 
        self.cache_key = cache_key
 
        self.cache_args = repo_name
 
        self.cache_active = False
 

	
 
    def __unicode__(self):
 
        return u"<%s('%s:%s[%s]')>" % (self.__class__.__name__,
 
                            self.cache_id, self.cache_key, self.cache_active)
 

	
 
    def _cache_key_partition(self):
 
        prefix, repo_name, suffix = self.cache_key.partition(self.cache_args)
 
        return prefix, repo_name, suffix
 

	
 
    def get_prefix(self):
 
        """
 
        get prefix that might have been used in _get_cache_key to
 
        generate self.cache_key. Only used for informational purposes
 
        in repo_edit.html.
 
        """
 
        # prefix, repo_name, suffix
 
        return self._cache_key_partition()[0]
 

	
 
    def get_suffix(self):
 
        """
 
        get suffix that might have been used in _get_cache_key to
 
        generate self.cache_key. Only used for informational purposes
 
        in repo_edit.html.
 
        """
 
        # prefix, repo_name, suffix
 
        return self._cache_key_partition()[2]
 

	
 
    @classmethod
 
    def clear_cache(cls):
 
        """
 
        Delete all cache keys from database.
 
        Should only be run when all instances are down and all entries thus stale.
 
        """
 
        cls.query().delete()
 
        Session().commit()
 

	
 
    @classmethod
 
    def _get_cache_key(cls, key):
 
        """
 
        Wrapper for generating a unique cache key for this instance and "key".
 
        key must / will start with a repo_name which will be stored in .cache_args .
 
        """
 
        import kallithea
 
        prefix = kallithea.CONFIG.get('instance_id', '')
 
        return "%s%s" % (prefix, key)
 

	
 
    @classmethod
 
    def set_invalidate(cls, repo_name, delete=False):
 
        """
 
        Mark all caches of a repo as invalid in the database.
 
        """
 
        inv_objs = Session().query(cls).filter(cls.cache_args == repo_name).all()
 
        log.debug('for repo %s got %s invalidation objects'
 
                  % (safe_str(repo_name), inv_objs))
 
        try:
 
            for inv_obj in inv_objs:
 
                log.debug('marking %s key for invalidation based on repo_name=%s'
 
                          % (inv_obj, safe_str(repo_name)))
 
                if delete:
 
                    Session().delete(inv_obj)
 
                else:
 
                    inv_obj.cache_active = False
 
                    Session().add(inv_obj)
 
            Session().commit()
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            Session().rollback()
 

	
 
        for inv_obj in inv_objs:
 
            log.debug('marking %s key for invalidation based on repo_name=%s'
 
                      % (inv_obj, safe_str(repo_name)))
 
            if delete:
 
                Session().delete(inv_obj)
 
            else:
 
                inv_obj.cache_active = False
 
                Session().add(inv_obj)
 
        Session().commit()
 

	
 
    @classmethod
 
    def test_and_set_valid(cls, repo_name, kind, valid_cache_keys=None):
 
        """
 
        Mark this cache key as active and currently cached.
 
        Return True if the existing cache registration still was valid.
 
        Return False to indicate that it had been invalidated and caches should be refreshed.
 
        """
 

	
 
        key = (repo_name + '_' + kind) if kind else repo_name
 
        cache_key = cls._get_cache_key(key)
 

	
 
        if valid_cache_keys and cache_key in valid_cache_keys:
 
            return True
 

	
 
        try:
 
            inv_obj = cls.query().filter(cls.cache_key == cache_key).scalar()
 
            if not inv_obj:
 
                inv_obj = CacheInvalidation(cache_key, repo_name)
 
            if inv_obj.cache_active:
 
                return True
 
            inv_obj.cache_active = True
 
            Session().add(inv_obj)
 
            Session().commit()
 
            return False
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            Session().rollback()
 
            return False
 
        inv_obj = cls.query().filter(cls.cache_key == cache_key).scalar()
 
        if not inv_obj:
 
            inv_obj = CacheInvalidation(cache_key, repo_name)
 
        if inv_obj.cache_active:
 
            return True
 
        inv_obj.cache_active = True
 
        Session().add(inv_obj)
 
        Session().commit()
 
        return False
 

	
 
    @classmethod
 
    def get_valid_cache_keys(cls):
 
        """
 
        Return opaque object with information of which caches still are valid
 
        and can be used without checking for invalidation.
 
        """
 
        return set(inv_obj.cache_key for inv_obj in cls.query().filter(cls.cache_active).all())
 

	
 

	
 
class ChangesetComment(Base, BaseModel):
 
    __tablename__ = 'changeset_comments'
 
    __table_args__ = (
 
        Index('cc_revision_idx', 'revision'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    comment_id = Column('comment_id', Integer(), nullable=False, primary_key=True)
 
    repo_id = Column('repo_id', Integer(), ForeignKey('repositories.repo_id'), nullable=False)
 
    revision = Column('revision', String(40), nullable=True)
 
    pull_request_id = Column("pull_request_id", Integer(), ForeignKey('pull_requests.pull_request_id'), nullable=True)
 
    line_no = Column('line_no', Unicode(10), nullable=True)
 
    hl_lines = Column('hl_lines', Unicode(512), nullable=True)
 
    f_path = Column('f_path', Unicode(1000), nullable=True)
 
    user_id = Column('user_id', Integer(), ForeignKey('users.user_id'), nullable=False)
 
    text = Column('text', UnicodeText(25000), nullable=False)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 
    modified_at = Column('modified_at', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 

	
 
    author = relationship('User', lazy='joined')
 
    repo = relationship('Repository')
 
    status_change = relationship('ChangesetStatus', cascade="all, delete-orphan")
 
    pull_request = relationship('PullRequest')
 

	
 
    @classmethod
 
    def get_users(cls, revision=None, pull_request_id=None):
 
        """
 
        Returns user associated with this ChangesetComment. ie those
 
        who actually commented
 

	
 
        :param cls:
 
        :param revision:
 
        """
 
        q = Session().query(User)\
 
                .join(ChangesetComment.author)
 
        if revision:
 
            q = q.filter(cls.revision == revision)
 
        elif pull_request_id:
 
            q = q.filter(cls.pull_request_id == pull_request_id)
 
        return q.all()
 

	
 

	
 
class ChangesetStatus(Base, BaseModel):
 
    __tablename__ = 'changeset_statuses'
 
    __table_args__ = (
 
        Index('cs_revision_idx', 'revision'),
 
        Index('cs_version_idx', 'version'),
 
        UniqueConstraint('repo_id', 'revision', 'version'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    STATUS_NOT_REVIEWED = DEFAULT = 'not_reviewed'
 
    STATUS_APPROVED = 'approved'
 
    STATUS_REJECTED = 'rejected'
 
    STATUS_UNDER_REVIEW = 'under_review'
 

	
 
    STATUSES = [
 
        (STATUS_NOT_REVIEWED, _("Not Reviewed")),  # (no icon) and default
 
        (STATUS_APPROVED, _("Approved")),
 
        (STATUS_REJECTED, _("Rejected")),
 
        (STATUS_UNDER_REVIEW, _("Under Review")),
 
    ]
 

	
 
    changeset_status_id = Column('changeset_status_id', Integer(), nullable=False, primary_key=True)
 
    repo_id = Column('repo_id', Integer(), ForeignKey('repositories.repo_id'), nullable=False)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None)
 
    revision = Column('revision', String(40), nullable=False)
 
    status = Column('status', String(128), nullable=False, default=DEFAULT)
 
    changeset_comment_id = Column('changeset_comment_id', Integer(), ForeignKey('changeset_comments.comment_id'))
 
    modified_at = Column('modified_at', DateTime(), nullable=False, default=datetime.datetime.now)
 
    version = Column('version', Integer(), nullable=False, default=0)
 
    pull_request_id = Column("pull_request_id", Integer(), ForeignKey('pull_requests.pull_request_id'), nullable=True)
 

	
 
    author = relationship('User', lazy='joined')
 
    repo = relationship('Repository')
 
    comment = relationship('ChangesetComment', lazy='joined')
 
    pull_request = relationship('PullRequest')
 

	
 
    def __unicode__(self):
 
        return u"<%s('%s:%s')>" % (
 
            self.__class__.__name__,
 
            self.status, self.author
 
        )
 

	
 
    @classmethod
 
    def get_status_lbl(cls, value):
 
        return dict(cls.STATUSES).get(value)
 

	
 
    @property
 
    def status_lbl(self):
 
        return ChangesetStatus.get_status_lbl(self.status)
 

	
 

	
 
class PullRequest(Base, BaseModel):
 
    __tablename__ = 'pull_requests'
 
    __table_args__ = (
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 

	
 
    # values for .status
 
    STATUS_NEW = u'new'
 
    STATUS_CLOSED = u'closed'
 

	
 
    pull_request_id = Column('pull_request_id', Integer(), nullable=False, primary_key=True)
 
    title = Column('title', Unicode(256), nullable=True)
 
    description = Column('description', UnicodeText(10240), nullable=True)
 
    status = Column('status', Unicode(256), nullable=False, default=STATUS_NEW) # only for closedness, not approve/reject/etc
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 
    updated_on = Column('updated_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None)
 
    _revisions = Column('revisions', UnicodeText(20500))  # 500 revisions max
 
    org_repo_id = Column('org_repo_id', Integer(), ForeignKey('repositories.repo_id'), nullable=False)
 
    org_ref = Column('org_ref', Unicode(256), nullable=False)
 
    other_repo_id = Column('other_repo_id', Integer(), ForeignKey('repositories.repo_id'), nullable=False)
 
    other_ref = Column('other_ref', Unicode(256), nullable=False)
 

	
 
    @hybrid_property
 
    def revisions(self):
 
        return self._revisions.split(':')
 

	
 
    @revisions.setter
 
    def revisions(self, val):
 
        self._revisions = safe_unicode(':'.join(val))
 

	
 
    @property
 
    def org_ref_parts(self):
 
        return self.org_ref.split(':')
 

	
 
    @property
 
    def other_ref_parts(self):
 
        return self.other_ref.split(':')
 

	
 
    author = relationship('User', lazy='joined')
 
    reviewers = relationship('PullRequestReviewers',
 
                             cascade="all, delete-orphan")
 
    org_repo = relationship('Repository', primaryjoin='PullRequest.org_repo_id==Repository.repo_id')
 
    other_repo = relationship('Repository', primaryjoin='PullRequest.other_repo_id==Repository.repo_id')
 
    statuses = relationship('ChangesetStatus')
 
    comments = relationship('ChangesetComment',
 
                             cascade="all, delete-orphan")
 

	
 
    def is_closed(self):
 
        return self.status == self.STATUS_CLOSED
 

	
 
    @property
 
    def last_review_status(self):
 
        return str(self.statuses[-1].status) if self.statuses else ''
 

	
 
    def __json__(self):
 
        return dict(
 
            revisions=self.revisions
 
        )
 

	
 
    def url(self, **kwargs):
 
        canonical = kwargs.pop('canonical', None)
 
        import kallithea.lib.helpers as h
 
        b = self.org_ref_parts[1]
 
        if b != self.other_ref_parts[1]:
 
            s = '/_/' + b
 
        else:
 
            s = '/_/' + self.title
 
        kwargs['extra'] = urlreadable(s)
 
        if canonical:
 
            return h.canonical_url('pullrequest_show', repo_name=self.other_repo.repo_name,
 
                                   pull_request_id=self.pull_request_id, **kwargs)
 
        return h.url('pullrequest_show', repo_name=self.other_repo.repo_name,
 
                     pull_request_id=self.pull_request_id, **kwargs)
 

	
 
class PullRequestReviewers(Base, BaseModel):
 
    __tablename__ = 'pull_request_reviewers'
 
    __table_args__ = (
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 

	
 
    def __init__(self, user=None, pull_request=None):
 
        self.user = user
 
        self.pull_request = pull_request
 

	
 
    pull_requests_reviewers_id = Column('pull_requests_reviewers_id', Integer(), nullable=False, primary_key=True)
 
    pull_request_id = Column("pull_request_id", Integer(), ForeignKey('pull_requests.pull_request_id'), nullable=False)
 
    user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=True)
 

	
 
    user = relationship('User')
 
    pull_request = relationship('PullRequest')
 

	
 

	
 
class Notification(Base, BaseModel):
 
    __tablename__ = 'notifications'
 
    __table_args__ = (
 
        Index('notification_type_idx', 'type'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 

	
 
    TYPE_CHANGESET_COMMENT = u'cs_comment'
 
    TYPE_MESSAGE = u'message'
 
    TYPE_MENTION = u'mention'
 
    TYPE_REGISTRATION = u'registration'
 
    TYPE_PULL_REQUEST = u'pull_request'
 
    TYPE_PULL_REQUEST_COMMENT = u'pull_request_comment'
 

	
 
    notification_id = Column('notification_id', Integer(), nullable=False, primary_key=True)
 
    subject = Column('subject', Unicode(512), nullable=True)
 
    body = Column('body', UnicodeText(50000), nullable=True)
 
    created_by = Column("created_by", Integer(), ForeignKey('users.user_id'), nullable=True)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 
    type_ = Column('type', Unicode(256))
 

	
 
    created_by_user = relationship('User')
 
    notifications_to_users = relationship('UserNotification', lazy='joined',
 
                                          cascade="all, delete-orphan")
 

	
 
    @property
 
    def recipients(self):
 
        return [x.user for x in UserNotification.query()\
 
                .filter(UserNotification.notification == self)\
 
                .order_by(UserNotification.user_id.asc()).all()]
 

	
 
    @classmethod
 
    def create(cls, created_by, subject, body, recipients, type_=None):
 
        if type_ is None:
 
            type_ = Notification.TYPE_MESSAGE
 

	
 
        notification = cls()
 
        notification.created_by_user = created_by
 
        notification.subject = subject
 
        notification.body = body
 
        notification.type_ = type_
 
        notification.created_on = datetime.datetime.now()
 

	
 
        for u in recipients:
 
            assoc = UserNotification()
 
            assoc.notification = notification
 
            u.notifications.append(assoc)
 
        Session().add(notification)
 
        return notification
 

	
 
    @property
 
    def description(self):
 
        from kallithea.model.notification import NotificationModel
 
        return NotificationModel().make_description(self)
 

	
 

	
 
class UserNotification(Base, BaseModel):
 
    __tablename__ = 'user_to_notification'
 
    __table_args__ = (
 
        UniqueConstraint('user_id', 'notification_id'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    user_id = Column('user_id', Integer(), ForeignKey('users.user_id'), primary_key=True)
 
    notification_id = Column("notification_id", Integer(), ForeignKey('notifications.notification_id'), primary_key=True)
 
    read = Column('read', Boolean, default=False)
 
    sent_on = Column('sent_on', DateTime(timezone=False), nullable=True, unique=None)
 

	
 
    user = relationship('User', lazy="joined")
 
    notification = relationship('Notification', lazy="joined",
 
                                order_by=lambda: Notification.created_on.desc(),)
 

	
 
    def mark_as_read(self):
 
        self.read = True
 
        Session().add(self)
 

	
 

	
 
class Gist(Base, BaseModel):
 
    __tablename__ = 'gists'
 
    __table_args__ = (
 
        Index('g_gist_access_id_idx', 'gist_access_id'),
 
        Index('g_created_on_idx', 'created_on'),
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True}
 
    )
 
    GIST_PUBLIC = u'public'
 
    GIST_PRIVATE = u'private'
 
    DEFAULT_FILENAME = u'gistfile1.txt'
 

	
 
    gist_id = Column('gist_id', Integer(), primary_key=True)
 
    gist_access_id = Column('gist_access_id', Unicode(250))
 
    gist_description = Column('gist_description', UnicodeText(1024))
 
    gist_owner = Column('user_id', Integer(), ForeignKey('users.user_id'), nullable=True)
 
    gist_expires = Column('gist_expires', Float(53), nullable=False)
 
    gist_type = Column('gist_type', Unicode(128), nullable=False)
 
    created_on = Column('created_on', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 
    modified_at = Column('modified_at', DateTime(timezone=False), nullable=False, default=datetime.datetime.now)
 

	
 
    owner = relationship('User')
 

	
 
    def __repr__(self):
 
        return '<Gist:[%s]%s>' % (self.gist_type, self.gist_access_id)
 

	
 
    @classmethod
 
    def get_or_404(cls, id_):
 
        res = cls.query().filter(cls.gist_access_id == id_).scalar()
 
        if not res:
 
            raise HTTPNotFound
 
        return res
 

	
 
    @classmethod
 
    def get_by_access_id(cls, gist_access_id):
 
        return cls.query().filter(cls.gist_access_id == gist_access_id).scalar()
 

	
 
    def gist_url(self):
 
        import kallithea
 
        alias_url = kallithea.CONFIG.get('gist_alias_url')
 
        if alias_url:
 
            return alias_url.replace('{gistid}', self.gist_access_id)
 

	
 
        import kallithea.lib.helpers as h
 
        return h.canonical_url('gist', gist_id=self.gist_access_id)
 

	
 
    @classmethod
 
    def base_path(cls):
 
        """
 
        Returns base path where all gists are stored
 

	
 
        :param cls:
 
        """
 
        from kallithea.model.gist import GIST_STORE_LOC
 
        q = Session().query(Ui)\
 
            .filter(Ui.ui_key == URL_SEP)
 
        q = q.options(FromCache("sql_cache_short", "repository_repo_path"))
 
        return os.path.join(q.one().ui_value, GIST_STORE_LOC)
 

	
 
    def get_api_data(self):
 
        """
 
        Common function for generating gist related data for API
 
        """
 
        gist = self
 
        data = dict(
 
            gist_id=gist.gist_id,
 
            type=gist.gist_type,
 
            access_id=gist.gist_access_id,
 
            description=gist.gist_description,
 
            url=gist.gist_url(),
 
            expires=gist.gist_expires,
 
            created_on=gist.created_on,
 
        )
 
        return data
 

	
 
    def __json__(self):
 
        data = dict(
 
        )
 
        data.update(self.get_api_data())
 
        return data
 
    ## SCM functions
 

	
 
    @property
 
    def scm_instance(self):
 
        from kallithea.lib.vcs import get_repo
 
        base_path = self.base_path()
 
        return get_repo(os.path.join(*map(safe_str,
 
                                          [base_path, self.gist_access_id])))
 

	
 

	
 
class DbMigrateVersion(Base, BaseModel):
 
    __tablename__ = 'db_migrate_version'
 
    __table_args__ = (
 
        {'extend_existing': True, 'mysql_engine': 'InnoDB',
 
         'mysql_charset': 'utf8', 'sqlite_autoincrement': True},
 
    )
 
    repository_id = Column('repository_id', String(250), primary_key=True)
 
    repository_path = Column('repository_path', Text)
 
    version = Column('version', Integer)
kallithea/model/user.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.model.user
 
~~~~~~~~~~~~~~~~~~~~
 

	
 
users model for Kallithea
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Apr 9, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 
"""
 

	
 

	
 
import logging
 
import traceback
 
from pylons.i18n.translation import _
 

	
 
from sqlalchemy.exc import DatabaseError
 

	
 
from kallithea import EXTERN_TYPE_INTERNAL
 
from kallithea.lib.utils2 import safe_unicode, generate_api_key, get_current_authuser
 
from kallithea.lib.caching_query import FromCache
 
from kallithea.model import BaseModel
 
from kallithea.model.db import User, UserToPerm, Notification, \
 
    UserEmailMap, UserIpMap
 
from kallithea.lib.exceptions import DefaultUserException, \
 
    UserOwnsReposException
 
from kallithea.model.meta import Session
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class UserModel(BaseModel):
 
    cls = User
 

	
 
    def get(self, user_id, cache=False):
 
        user = self.sa.query(User)
 
        if cache:
 
            user = user.options(FromCache("sql_cache_short",
 
                                          "get_user_%s" % user_id))
 
        return user.get(user_id)
 

	
 
    def get_user(self, user):
 
        return self._get_user(user)
 

	
 
    def get_by_username(self, username, cache=False, case_insensitive=False):
 

	
 
        if case_insensitive:
 
            user = self.sa.query(User).filter(User.username.ilike(username))
 
        else:
 
            user = self.sa.query(User)\
 
                .filter(User.username == username)
 
        if cache:
 
            user = user.options(FromCache("sql_cache_short",
 
                                          "get_user_%s" % username))
 
        return user.scalar()
 

	
 
    def get_by_email(self, email, cache=False, case_insensitive=False):
 
        return User.get_by_email(email, case_insensitive, cache)
 

	
 
    def get_by_api_key(self, api_key, cache=False):
 
        return User.get_by_api_key(api_key, cache)
 

	
 
    def create(self, form_data, cur_user=None):
 
        if not cur_user:
 
            cur_user = getattr(get_current_authuser(), 'username', None)
 

	
 
        from kallithea.lib.hooks import log_create_user, check_allowed_create_user
 
        _fd = form_data
 
        user_data = {
 
            'username': _fd['username'], 'password': _fd['password'],
 
            'email': _fd['email'], 'firstname': _fd['firstname'], 'lastname': _fd['lastname'],
 
            'active': _fd['active'], 'admin': False
 
        }
 
        # raises UserCreationError if it's not allowed
 
        check_allowed_create_user(user_data, cur_user)
 
        from kallithea.lib.auth import get_crypt_password
 
        try:
 
            new_user = User()
 
            for k, v in form_data.items():
 
                if k == 'password':
 
                    v = get_crypt_password(v)
 
                if k == 'firstname':
 
                    k = 'name'
 
                setattr(new_user, k, v)
 

	
 
            new_user.api_key = generate_api_key(form_data['username'])
 
            self.sa.add(new_user)
 
        new_user = User()
 
        for k, v in form_data.items():
 
            if k == 'password':
 
                v = get_crypt_password(v)
 
            if k == 'firstname':
 
                k = 'name'
 
            setattr(new_user, k, v)
 

	
 
            log_create_user(new_user.get_dict(), cur_user)
 
            return new_user
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            raise
 
        new_user.api_key = generate_api_key(form_data['username'])
 
        self.sa.add(new_user)
 

	
 
        log_create_user(new_user.get_dict(), cur_user)
 
        return new_user
 

	
 
    def create_or_update(self, username, password, email, firstname='',
 
                         lastname='', active=True, admin=False,
 
                         extern_type=None, extern_name=None, cur_user=None):
 
        """
 
        Creates a new instance if not found, or updates current one
 

	
 
        :param username:
 
        :param password:
 
        :param email:
 
        :param active:
 
        :param firstname:
 
        :param lastname:
 
        :param active:
 
        :param admin:
 
        :param extern_name:
 
        :param extern_type:
 
        :param cur_user:
 
        """
 
        if not cur_user:
 
            cur_user = getattr(get_current_authuser(), 'username', None)
 

	
 
        from kallithea.lib.auth import get_crypt_password, check_password
 
        from kallithea.lib.hooks import log_create_user, check_allowed_create_user
 
        user_data = {
 
            'username': username, 'password': password,
 
            'email': email, 'firstname': firstname, 'lastname': lastname,
 
            'active': active, 'admin': admin
 
        }
 
        # raises UserCreationError if it's not allowed
 
        check_allowed_create_user(user_data, cur_user)
 

	
 
        log.debug('Checking for %s account in Kallithea database' % username)
 
        user = User.get_by_username(username, case_insensitive=True)
 
        if user is None:
 
            log.debug('creating new user %s' % username)
 
            new_user = User()
 
            edit = False
 
        else:
 
            log.debug('updating user %s' % username)
 
            new_user = user
 
            edit = True
 

	
 
        try:
 
            new_user.username = username
 
            new_user.admin = admin
 
            new_user.email = email
 
            new_user.active = active
 
            new_user.extern_name = safe_unicode(extern_name) if extern_name else None
 
            new_user.extern_type = safe_unicode(extern_type) if extern_type else None
 
            new_user.name = firstname
 
            new_user.lastname = lastname
 

	
 
            if not edit:
 
                new_user.api_key = generate_api_key(username)
 

	
 
            # set password only if creating an user or password is changed
 
            password_change = new_user.password and not check_password(password,
 
                                                            new_user.password)
 
            if not edit or password_change:
 
                reason = 'new password' if edit else 'new user'
 
                log.debug('Updating password reason=>%s' % (reason,))
 
                new_user.password = get_crypt_password(password) if password else None
 

	
 
            self.sa.add(new_user)
 

	
 
            if not edit:
 
                log_create_user(new_user.get_dict(), cur_user)
 
            return new_user
 
        except (DatabaseError,):
 
            log.error(traceback.format_exc())
 
            raise
 

	
 
    def create_registration(self, form_data):
 
        from kallithea.model.notification import NotificationModel
 
        import kallithea.lib.helpers as h
 

	
 
        try:
 
            form_data['admin'] = False
 
            form_data['extern_name'] = EXTERN_TYPE_INTERNAL
 
            form_data['extern_type'] = EXTERN_TYPE_INTERNAL
 
            new_user = self.create(form_data)
 
        form_data['admin'] = False
 
        form_data['extern_name'] = EXTERN_TYPE_INTERNAL
 
        form_data['extern_type'] = EXTERN_TYPE_INTERNAL
 
        new_user = self.create(form_data)
 

	
 
            self.sa.add(new_user)
 
            self.sa.flush()
 
        self.sa.add(new_user)
 
        self.sa.flush()
 

	
 
            # notification to admins
 
            subject = _('New user registration')
 
            body = ('New user registration\n'
 
                    '---------------------\n'
 
                    '- Username: %s\n'
 
                    '- Full Name: %s\n'
 
                    '- Email: %s\n')
 
            body = body % (new_user.username, new_user.full_name, new_user.email)
 
            edit_url = h.canonical_url('edit_user', id=new_user.user_id)
 
            email_kwargs = {'registered_user_url': edit_url, 'new_username': new_user.username}
 
            NotificationModel().create(created_by=new_user, subject=subject,
 
                                       body=body, recipients=None,
 
                                       type_=Notification.TYPE_REGISTRATION,
 
                                       email_kwargs=email_kwargs)
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            raise
 
        # notification to admins
 
        subject = _('New user registration')
 
        body = ('New user registration\n'
 
                '---------------------\n'
 
                '- Username: %s\n'
 
                '- Full Name: %s\n'
 
                '- Email: %s\n')
 
        body = body % (new_user.username, new_user.full_name, new_user.email)
 
        edit_url = h.canonical_url('edit_user', id=new_user.user_id)
 
        email_kwargs = {'registered_user_url': edit_url, 'new_username': new_user.username}
 
        NotificationModel().create(created_by=new_user, subject=subject,
 
                                   body=body, recipients=None,
 
                                   type_=Notification.TYPE_REGISTRATION,
 
                                   email_kwargs=email_kwargs)
 

	
 
    def update(self, user_id, form_data, skip_attrs=[]):
 
        from kallithea.lib.auth import get_crypt_password
 
        try:
 
            user = self.get(user_id, cache=False)
 
            if user.username == User.DEFAULT_USER:
 
                raise DefaultUserException(
 
                                _("You can't Edit this user since it's "
 
                                  "crucial for entire application"))
 

	
 
        user = self.get(user_id, cache=False)
 
        if user.username == User.DEFAULT_USER:
 
            raise DefaultUserException(
 
                            _("You can't Edit this user since it's "
 
                              "crucial for entire application"))
 

	
 
            for k, v in form_data.items():
 
                if k in skip_attrs:
 
                    continue
 
                if k == 'new_password' and v:
 
                    user.password = get_crypt_password(v)
 
                else:
 
                    # old legacy thing orm models store firstname as name,
 
                    # need proper refactor to username
 
                    if k == 'firstname':
 
                        k = 'name'
 
                    setattr(user, k, v)
 
            self.sa.add(user)
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            raise
 
        for k, v in form_data.items():
 
            if k in skip_attrs:
 
                continue
 
            if k == 'new_password' and v:
 
                user.password = get_crypt_password(v)
 
            else:
 
                # old legacy thing orm models store firstname as name,
 
                # need proper refactor to username
 
                if k == 'firstname':
 
                    k = 'name'
 
                setattr(user, k, v)
 
        self.sa.add(user)
 

	
 
    def update_user(self, user, **kwargs):
 
        from kallithea.lib.auth import get_crypt_password
 
        try:
 
            user = self._get_user(user)
 
            if user.username == User.DEFAULT_USER:
 
                raise DefaultUserException(
 
                    _("You can't Edit this user since it's"
 
                      " crucial for entire application")
 
                )
 

	
 
        user = self._get_user(user)
 
        if user.username == User.DEFAULT_USER:
 
            raise DefaultUserException(
 
                _("You can't Edit this user since it's"
 
                  " crucial for entire application")
 
            )
 

	
 
            for k, v in kwargs.items():
 
                if k == 'password' and v:
 
                    v = get_crypt_password(v)
 
        for k, v in kwargs.items():
 
            if k == 'password' and v:
 
                v = get_crypt_password(v)
 

	
 
                setattr(user, k, v)
 
            self.sa.add(user)
 
            return user
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            raise
 
            setattr(user, k, v)
 
        self.sa.add(user)
 
        return user
 

	
 
    def delete(self, user, cur_user=None):
 
        if not cur_user:
 
            cur_user = getattr(get_current_authuser(), 'username', None)
 
        user = self._get_user(user)
 

	
 
        try:
 
            if user.username == User.DEFAULT_USER:
 
                raise DefaultUserException(
 
                    _(u"You can't remove this user since it's"
 
                      " crucial for entire application")
 
                )
 
            if user.repositories:
 
                repos = [x.repo_name for x in user.repositories]
 
                raise UserOwnsReposException(
 
                    _(u'user "%s" still owns %s repositories and cannot be '
 
                      'removed. Switch owners or remove those repositories. %s')
 
                    % (user.username, len(repos), ', '.join(repos))
 
                )
 
            self.sa.delete(user)
 
        if user.username == User.DEFAULT_USER:
 
            raise DefaultUserException(
 
                _(u"You can't remove this user since it's"
 
                  " crucial for entire application")
 
            )
 
        if user.repositories:
 
            repos = [x.repo_name for x in user.repositories]
 
            raise UserOwnsReposException(
 
                _(u'user "%s" still owns %s repositories and cannot be '
 
                  'removed. Switch owners or remove those repositories. %s')
 
                % (user.username, len(repos), ', '.join(repos))
 
            )
 
        self.sa.delete(user)
 

	
 
            from kallithea.lib.hooks import log_delete_user
 
            log_delete_user(user.get_dict(), cur_user)
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            raise
 
        from kallithea.lib.hooks import log_delete_user
 
        log_delete_user(user.get_dict(), cur_user)
 

	
 
    def reset_password_link(self, data):
 
        from kallithea.lib.celerylib import tasks, run_task
 
        from kallithea.model.notification import EmailNotificationModel
 
        import kallithea.lib.helpers as h
 

	
 
        user_email = data['email']
 
        try:
 
            user = User.get_by_email(user_email)
 
            if user:
 
                log.debug('password reset user found %s' % user)
 
                link = h.canonical_url('reset_password_confirmation', key=user.api_key)
 
                reg_type = EmailNotificationModel.TYPE_PASSWORD_RESET
 
                body = EmailNotificationModel().get_email_tmpl(reg_type,
 
                                                               'txt',
 
                                                               user=user.short_contact,
 
                                                               reset_url=link)
 
                html_body = EmailNotificationModel().get_email_tmpl(reg_type,
 
                                                               'html',
 
                                                               user=user.short_contact,
 
                                                               reset_url=link)
 
                log.debug('sending email')
 
                run_task(tasks.send_email, [user_email],
 
                         _("Password reset link"), body, html_body)
 
                log.info('send new password mail to %s' % user_email)
 
            else:
 
                log.debug("password reset email %s not found" % user_email)
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            return False
 
        user = User.get_by_email(user_email)
 
        if user:
 
            log.debug('password reset user found %s' % user)
 
            link = h.canonical_url('reset_password_confirmation', key=user.api_key)
 
            reg_type = EmailNotificationModel.TYPE_PASSWORD_RESET
 
            body = EmailNotificationModel().get_email_tmpl(reg_type,
 
                                                           'txt',
 
                                                           user=user.short_contact,
 
                                                           reset_url=link)
 
            html_body = EmailNotificationModel().get_email_tmpl(reg_type,
 
                                                           'html',
 
                                                           user=user.short_contact,
 
                                                           reset_url=link)
 
            log.debug('sending email')
 
            run_task(tasks.send_email, [user_email],
 
                     _("Password reset link"), body, html_body)
 
            log.info('send new password mail to %s' % user_email)
 
        else:
 
            log.debug("password reset email %s not found" % user_email)
 

	
 
        return True
 

	
 
    def reset_password(self, data):
 
        from kallithea.lib.celerylib import tasks, run_task
 
        from kallithea.lib import auth
 
        user_email = data['email']
 
        pre_db = True
 
        try:
 
            user = User.get_by_email(user_email)
 
            new_passwd = auth.PasswordGenerator().gen_password(8,
 
                            auth.PasswordGenerator.ALPHABETS_BIG_SMALL)
 
            if user:
 
                user.password = auth.get_crypt_password(new_passwd)
 
                Session().add(user)
 
                Session().commit()
 
                log.info('change password for %s' % user_email)
 
            if new_passwd is None:
 
                raise Exception('unable to generate new password')
 
        user = User.get_by_email(user_email)
 
        new_passwd = auth.PasswordGenerator().gen_password(8,
 
                        auth.PasswordGenerator.ALPHABETS_BIG_SMALL)
 
        if user:
 
            user.password = auth.get_crypt_password(new_passwd)
 
            Session().add(user)
 
            Session().commit()
 
            log.info('change password for %s' % user_email)
 
        if new_passwd is None:
 
            raise Exception('unable to generate new password')
 

	
 
            pre_db = False
 
            run_task(tasks.send_email, [user_email],
 
                     _('Your new password'),
 
                     _('Your new Kallithea password:%s') % (new_passwd,))
 
            log.info('send new password mail to %s' % user_email)
 

	
 
        except Exception:
 
            log.error('Failed to update user password')
 
            log.error(traceback.format_exc())
 
            if pre_db:
 
                # we rollback only if local db stuff fails. If it goes into
 
                # run_task, we're pass rollback state this wouldn't work then
 
                Session().rollback()
 
        run_task(tasks.send_email, [user_email],
 
                 _('Your new password'),
 
                 _('Your new Kallithea password:%s') % (new_passwd,))
 
        log.info('send new password mail to %s' % user_email)
 

	
 
        return True
 

	
 
    def fill_data(self, auth_user, user_id=None, api_key=None, username=None):
 
        """
 
        Fetches auth_user by user_id,or api_key if present.
 
        Fills auth_user attributes with those taken from database.
 
        Additionally sets is_authenitated if lookup fails
 
        present in database
 

	
 
        :param auth_user: instance of user to set attributes
 
        :param user_id: user id to fetch by
 
        :param api_key: api key to fetch by
 
        :param username: username to fetch by
 
        """
 
        if user_id is None and api_key is None and username is None:
 
            raise Exception('You need to pass user_id, api_key or username')
 

	
 
        try:
 
            dbuser = None
 
            if user_id:
 
                dbuser = self.get(user_id)
 
            elif api_key:
 
                dbuser = self.get_by_api_key(api_key)
 
            elif username:
 
                dbuser = self.get_by_username(username)
 
        dbuser = None
 
        if user_id:
 
            dbuser = self.get(user_id)
 
        elif api_key:
 
            dbuser = self.get_by_api_key(api_key)
 
        elif username:
 
            dbuser = self.get_by_username(username)
 

	
 
            if dbuser is not None and dbuser.active:
 
                log.debug('filling %s data' % dbuser)
 
                for k, v in dbuser.get_dict().iteritems():
 
                    if k not in ['api_keys', 'permissions']:
 
                        setattr(auth_user, k, v)
 
            else:
 
                return False
 

	
 
        except Exception:
 
            log.error(traceback.format_exc())
 
            auth_user.is_authenticated = False
 
            return False
 

	
 
        return True
 
        if dbuser is not None and dbuser.active:
 
            log.debug('filling %s data' % dbuser)
 
            for k, v in dbuser.get_dict().iteritems():
 
                if k not in ['api_keys', 'permissions']:
 
                    setattr(auth_user, k, v)
 
            return True
 
        return False
 

	
 
    def has_perm(self, user, perm):
 
        perm = self._get_perm(perm)
 
        user = self._get_user(user)
 

	
 
        return UserToPerm.query().filter(UserToPerm.user == user)\
 
            .filter(UserToPerm.permission == perm).scalar() is not None
 

	
 
    def grant_perm(self, user, perm):
 
        """
 
        Grant user global permissions
 

	
 
        :param user:
 
        :param perm:
 
        """
 
        user = self._get_user(user)
 
        perm = self._get_perm(perm)
 
        # if this permission is already granted skip it
 
        _perm = UserToPerm.query()\
 
            .filter(UserToPerm.user == user)\
 
            .filter(UserToPerm.permission == perm)\
 
            .scalar()
 
        if _perm:
 
            return
 
        new = UserToPerm()
 
        new.user = user
 
        new.permission = perm
 
        self.sa.add(new)
 
        return new
 

	
 
    def revoke_perm(self, user, perm):
 
        """
 
        Revoke users global permissions
 

	
 
        :param user:
 
        :param perm:
 
        """
 
        user = self._get_user(user)
 
        perm = self._get_perm(perm)
 

	
 
        obj = UserToPerm.query()\
 
                .filter(UserToPerm.user == user)\
 
                .filter(UserToPerm.permission == perm)\
 
                .scalar()
 
        if obj:
 
            self.sa.delete(obj)
 

	
 
    def add_extra_email(self, user, email):
 
        """
 
        Adds email address to UserEmailMap
 

	
 
        :param user:
 
        :param email:
 
        """
 
        from kallithea.model import forms
 
        form = forms.UserExtraEmailForm()()
 
        data = form.to_python(dict(email=email))
 
        user = self._get_user(user)
 

	
 
        obj = UserEmailMap()
 
        obj.user = user
 
        obj.email = data['email']
 
        self.sa.add(obj)
 
        return obj
 

	
 
    def delete_extra_email(self, user, email_id):
 
        """
 
        Removes email address from UserEmailMap
 

	
 
        :param user:
 
        :param email_id:
 
        """
 
        user = self._get_user(user)
 
        obj = UserEmailMap.query().get(email_id)
 
        if obj:
 
            self.sa.delete(obj)
 

	
 
    def add_extra_ip(self, user, ip):
 
        """
 
        Adds ip address to UserIpMap
 

	
 
        :param user:
 
        :param ip:
 
        """
 
        from kallithea.model import forms
 
        form = forms.UserExtraIpForm()()
 
        data = form.to_python(dict(ip=ip))
 
        user = self._get_user(user)
 

	
 
        obj = UserIpMap()
 
        obj.user = user
 
        obj.ip_addr = data['ip']
 
        self.sa.add(obj)
 
        return obj
 

	
 
    def delete_extra_ip(self, user, ip_id):
 
        """
 
        Removes ip address from UserIpMap
 

	
 
        :param user:
 
        :param ip_id:
 
        """
 
        user = self._get_user(user)
 
        obj = UserIpMap.query().get(ip_id)
 
        if obj:
 
            self.sa.delete(obj)
kallithea/model/validators.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
Set of generic validators
 
"""
 

	
 
import os
 
import re
 
import formencode
 
import logging
 
from collections import defaultdict
 
from pylons.i18n.translation import _
 
from webhelpers.pylonslib.secure_form import authentication_token
 
import sqlalchemy
 

	
 
from formencode.validators import (
 
    UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set,
 
    NotEmpty, IPAddress, CIDR, String, FancyValidator
 
)
 
from kallithea.lib.compat import OrderedSet
 
from kallithea.lib import ipaddr
 
from kallithea.lib.utils import repo_name_slug
 
from kallithea.lib.utils2 import safe_int, str2bool, aslist
 
from kallithea.model.db import RepoGroup, Repository, UserGroup, User,\
 
    ChangesetStatus
 
from kallithea.lib.exceptions import LdapImportError
 
from kallithea.config.routing import ADMIN_PREFIX
 
from kallithea.lib.auth import HasRepoGroupPermissionAny, HasPermissionAny
 

	
 
# silence warnings and pylint
 
UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set, \
 
    NotEmpty, IPAddress, CIDR, String, FancyValidator
 

	
 
log = logging.getLogger(__name__)
 

	
 
class _Missing(object):
 
    pass
 

	
 
Missing = _Missing()
 

	
 

	
 
class StateObj(object):
 
    """
 
    this is needed to translate the messages using _() in validators
 
    """
 
    _ = staticmethod(_)
 

	
 

	
 
def M(self, key, state=None, **kwargs):
 
    """
 
    returns string from self.message based on given key,
 
    passed kw params are used to substitute %(named)s params inside
 
    translated strings
 

	
 
    :param msg:
 
    :param state:
 
    """
 
    if state is None:
 
        state = StateObj()
 
    else:
 
        state._ = staticmethod(_)
 
    #inject validator into state object
 
    return self.message(key, state, **kwargs)
 

	
 

	
 
def UniqueListFromString():
 
    class _UniqueListFromString(formencode.FancyValidator):
 
        """
 
        Split value on ',' and make unique while preserving order
 
        """
 
        messages = dict(
 
            empty=_('Value cannot be an empty list'),
 
            missing_value=_('Value cannot be an empty list'),
 
        )
 

	
 
        def _to_python(self, value, state):
 
            value = aslist(value, ',')
 
            seen = set()
 
            return [c for c in value if not (c in seen or seen.add(c))]
 

	
 
        def empty_value(self, value):
 
            return []
 

	
 
    return _UniqueListFromString
 

	
 

	
 
def ValidUsername(edit=False, old_data={}):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'username_exists': _(u'Username "%(username)s" already exists'),
 
            'system_invalid_username':
 
                _(u'Username "%(username)s" is forbidden'),
 
            'invalid_username':
 
                _(u'Username may only contain alphanumeric characters '
 
                    'underscores, periods or dashes and must begin with '
 
                    'alphanumeric character or underscore')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            if value in ['default', 'new_user']:
 
                msg = M(self, 'system_invalid_username', state, username=value)
 
                raise formencode.Invalid(msg, value, state)
 
            #check if user is unique
 
            old_un = None
 
            if edit:
 
                old_un = User.get(old_data.get('user_id')).username
 

	
 
            if old_un != value or not edit:
 
                if User.get_by_username(value, case_insensitive=True):
 
                    msg = M(self, 'username_exists', state, username=value)
 
                    raise formencode.Invalid(msg, value, state)
 

	
 
            if re.match(r'^[a-zA-Z0-9\_]{1}[a-zA-Z0-9\-\_\.]*$', value) is None:
 
                msg = M(self, 'invalid_username', state)
 
                raise formencode.Invalid(msg, value, state)
 
    return _validator
 

	
 

	
 
def ValidRegex(msg=None):
 
    class _validator(formencode.validators.Regex):
 
        messages = dict(invalid=msg or _('The input is not valid'))
 
    return _validator
 

	
 

	
 
def ValidRepoUser():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_username': _(u'Username %(username)s is not valid')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            try:
 
                User.query().filter(User.active == True)\
 
                    .filter(User.username == value).one()
 
            except Exception:
 
            except sqlalchemy.exc.InvalidRequestError: # NoResultFound/MultipleResultsFound
 
                msg = M(self, 'invalid_username', state, username=value)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(username=msg)
 
                )
 

	
 
    return _validator
 

	
 

	
 
def ValidUserGroup(edit=False, old_data={}):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_group': _(u'Invalid user group name'),
 
            'group_exist': _(u'User group "%(usergroup)s" already exists'),
 
            'invalid_usergroup_name':
 
                _(u'user group name may only contain alphanumeric '
 
                  'characters underscores, periods or dashes and must begin '
 
                  'with alphanumeric character')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            if value in ['default']:
 
                msg = M(self, 'invalid_group', state)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(users_group_name=msg)
 
                )
 
            #check if group is unique
 
            old_ugname = None
 
            if edit:
 
                old_id = old_data.get('users_group_id')
 
                old_ugname = UserGroup.get(old_id).users_group_name
 

	
 
            if old_ugname != value or not edit:
 
                is_existing_group = UserGroup.get_by_group_name(value,
 
                                                        case_insensitive=True)
 
                if is_existing_group:
 
                    msg = M(self, 'group_exist', state, usergroup=value)
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(users_group_name=msg)
 
                    )
 

	
 
            if re.match(r'^[a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+$', value) is None:
 
                msg = M(self, 'invalid_usergroup_name', state)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(users_group_name=msg)
 
                )
 

	
 
    return _validator
 

	
 

	
 
def ValidRepoGroup(edit=False, old_data={}):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'group_parent_id': _(u'Cannot assign this group as parent'),
 
            'group_exists': _(u'Group "%(group_name)s" already exists'),
 
            'repo_exists':
 
                _(u'Repository with name "%(group_name)s" already exists')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            # TODO WRITE VALIDATIONS
 
            group_name = value.get('group_name')
 
            group_parent_id = value.get('group_parent_id')
 

	
 
            # slugify repo group just in case :)
 
            slug = repo_name_slug(group_name)
 

	
 
            # check for parent of self
 
            parent_of_self = lambda: (
 
                old_data['group_id'] == int(group_parent_id)
 
                if group_parent_id else False
 
            )
 
            if edit and parent_of_self():
 
                msg = M(self, 'group_parent_id', state)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(group_parent_id=msg)
 
                )
 

	
 
            old_gname = None
 
            if edit:
 
                old_gname = RepoGroup.get(old_data.get('group_id')).group_name
 

	
 
            if old_gname != group_name or not edit:
 

	
 
                # check group
 
                gr = RepoGroup.query()\
 
                      .filter(RepoGroup.group_name == slug)\
 
                      .filter(RepoGroup.group_parent_id == group_parent_id)\
 
                      .scalar()
 

	
 
                if gr:
 
                    msg = M(self, 'group_exists', state, group_name=slug)
 
                    raise formencode.Invalid(msg, value, state,
 
                            error_dict=dict(group_name=msg)
 
                    )
 

	
 
                # check for same repo
 
                repo = Repository.query()\
 
                      .filter(Repository.repo_name == slug)\
 
                      .scalar()
 

	
 
                if repo:
 
                    msg = M(self, 'repo_exists', state, group_name=slug)
 
                    raise formencode.Invalid(msg, value, state,
 
                            error_dict=dict(group_name=msg)
 
                    )
 

	
 
    return _validator
 

	
 

	
 
def ValidPassword():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_password':
 
                _(u'Invalid characters (non-ascii) in password')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            try:
 
                (value or '').decode('ascii')
 
            except UnicodeError:
 
                msg = M(self, 'invalid_password', state)
 
                raise formencode.Invalid(msg, value, state,)
 
    return _validator
 

	
 

	
 
def ValidOldPassword(username):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_password': _(u'Invalid old password')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            from kallithea.lib import auth_modules
 
            if not auth_modules.authenticate(username, value, ''):
 
                msg = M(self, 'invalid_password', state)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(current_password=msg)
 
                )
 
    return _validator
 

	
 

	
 
def ValidPasswordsMatch(passwd='new_password', passwd_confirmation='password_confirmation'):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'password_mismatch': _(u'Passwords do not match'),
 
        }
 

	
 
        def validate_python(self, value, state):
 

	
 
            pass_val = value.get('password') or value.get(passwd)
 
            if pass_val != value[passwd_confirmation]:
 
                msg = M(self, 'password_mismatch', state)
 
                raise formencode.Invalid(msg, value, state,
 
                     error_dict={passwd:msg, passwd_confirmation: msg}
 
                )
 
    return _validator
 

	
 

	
 
def ValidAuth():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_password': _(u'invalid password'),
 
            'invalid_username': _(u'invalid user name'),
 
            'disabled_account': _(u'Your account is disabled')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            from kallithea.lib import auth_modules
 

	
 
            password = value['password']
 
            username = value['username']
 

	
 
            if not auth_modules.authenticate(username, password):
 
                user = User.get_by_username(username)
 
                if user and not user.active:
 
                    log.warning('user %s is disabled' % username)
 
                    msg = M(self, 'disabled_account', state)
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(username=msg)
 
                    )
 
                else:
 
                    log.warning('user %s failed to authenticate' % username)
 
                    msg = M(self, 'invalid_username', state)
 
                    msg2 = M(self, 'invalid_password', state)
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(username=msg, password=msg2)
 
                    )
 
    return _validator
 

	
 

	
 
def ValidAuthToken():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_token': _(u'Token mismatch')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            if value != authentication_token():
 
                msg = M(self, 'invalid_token', state)
 
                raise formencode.Invalid(msg, value, state)
 
    return _validator
 

	
 

	
 
def ValidRepoName(edit=False, old_data={}):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_repo_name':
 
                _(u'Repository name %(repo)s is disallowed'),
 
            'repository_exists':
 
                _(u'Repository named %(repo)s already exists'),
 
            'repository_in_group_exists': _(u'Repository "%(repo)s" already '
 
                                            'exists in group "%(group)s"'),
 
            'same_group_exists': _(u'Repository group with name "%(repo)s" '
 
                                   'already exists')
 
        }
 

	
 
        def _to_python(self, value, state):
 
            repo_name = repo_name_slug(value.get('repo_name', ''))
 
            repo_group = value.get('repo_group')
 
            if repo_group:
 
                gr = RepoGroup.get(repo_group)
 
                group_path = gr.full_path
 
                group_name = gr.group_name
 
                # value needs to be aware of group name in order to check
 
                # db key This is an actual just the name to store in the
 
                # database
 
                repo_name_full = group_path + RepoGroup.url_sep() + repo_name
 
            else:
 
                group_name = group_path = ''
 
                repo_name_full = repo_name
 

	
 
            value['repo_name'] = repo_name
 
            value['repo_name_full'] = repo_name_full
 
            value['group_path'] = group_path
 
            value['group_name'] = group_name
 
            return value
 

	
 
        def validate_python(self, value, state):
 

	
 
            repo_name = value.get('repo_name')
 
            repo_name_full = value.get('repo_name_full')
 
            group_path = value.get('group_path')
 
            group_name = value.get('group_name')
 

	
 
            if repo_name in [ADMIN_PREFIX, '']:
 
                msg = M(self, 'invalid_repo_name', state, repo=repo_name)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(repo_name=msg)
 
                )
 

	
 
            rename = old_data.get('repo_name') != repo_name_full
 
            create = not edit
 
            if rename or create:
 

	
 
                if group_path != '':
 
                    if Repository.get_by_repo_name(repo_name_full):
 
                        msg = M(self, 'repository_in_group_exists', state,
 
                                repo=repo_name, group=group_name)
 
                        raise formencode.Invalid(msg, value, state,
 
                            error_dict=dict(repo_name=msg)
 
                        )
 
                elif RepoGroup.get_by_group_name(repo_name_full):
 
                        msg = M(self, 'same_group_exists', state,
 
                                repo=repo_name)
 
                        raise formencode.Invalid(msg, value, state,
 
                            error_dict=dict(repo_name=msg)
 
                        )
 

	
 
                elif Repository.get_by_repo_name(repo_name_full):
 
                        msg = M(self, 'repository_exists', state,
 
                                repo=repo_name)
 
                        raise formencode.Invalid(msg, value, state,
 
                            error_dict=dict(repo_name=msg)
 
                        )
 
            return value
 
    return _validator
 

	
 

	
 
def ValidForkName(*args, **kwargs):
 
    return ValidRepoName(*args, **kwargs)
 

	
 

	
 
def SlugifyName():
 
    class _validator(formencode.validators.FancyValidator):
 

	
 
        def _to_python(self, value, state):
 
            return repo_name_slug(value)
 

	
 
        def validate_python(self, value, state):
 
            pass
 

	
 
    return _validator
 

	
 

	
 
def ValidCloneUri():
 
    from kallithea.lib.utils import make_ui
 

	
 
    def url_handler(repo_type, url, ui):
 
        if repo_type == 'hg':
 
            from kallithea.lib.vcs.backends.hg.repository import MercurialRepository
 
            if url.startswith('http'):
 
                # initially check if it's at least the proper URL
 
                # or does it pass basic auth
 
                MercurialRepository._check_url(url, ui)
 
            elif url.startswith('svn+http'):
 
                from hgsubversion.svnrepo import svnremoterepo
 
                svnremoterepo(ui, url).svn.uuid
 
            elif url.startswith('git+http'):
 
                raise NotImplementedError()
 
            else:
 
                raise Exception('clone from URI %s not allowed' % (url,))
 

	
 
        elif repo_type == 'git':
 
            from kallithea.lib.vcs.backends.git.repository import GitRepository
 
            if url.startswith('http'):
 
                # initially check if it's at least the proper URL
 
                # or does it pass basic auth
 
                GitRepository._check_url(url)
 
            elif url.startswith('svn+http'):
 
                raise NotImplementedError()
 
            elif url.startswith('hg+http'):
 
                raise NotImplementedError()
 
            else:
 
                raise Exception('clone from URI %s not allowed' % (url))
 

	
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'clone_uri': _(u'invalid clone url'),
 
            'invalid_clone_uri': _(u'Invalid clone url, provide a '
 
                                    'valid clone http(s)/svn+http(s) url')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            repo_type = value.get('repo_type')
 
            url = value.get('clone_uri')
 

	
 
            if not url:
 
                pass
 
            else:
 
                try:
 
                    url_handler(repo_type, url, make_ui('db', clear_session=False))
 
                except Exception:
 
                    log.exception('Url validation failed')
 
                    msg = M(self, 'clone_uri')
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(clone_uri=msg)
 
                    )
 
    return _validator
 

	
 

	
 
def ValidForkType(old_data={}):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_fork_type': _(u'Fork have to be the same type as parent')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            if old_data['repo_type'] != value:
 
                msg = M(self, 'invalid_fork_type', state)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(repo_type=msg)
 
                )
 
    return _validator
 

	
 

	
 
def CanWriteGroup(old_data=None):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'permission_denied': _(u"You don't have permissions "
 
                                   "to create repository in this group"),
 
            'permission_denied_root': _(u"no permission to create repository "
 
                                        "in root location")
 
        }
 

	
 
        def _to_python(self, value, state):
 
            #root location
 
            if value in [-1, "-1"]:
 
                return None
 
            return value
 

	
 
        def validate_python(self, value, state):
 
            gr = RepoGroup.get(value)
 
            gr_name = gr.group_name if gr else None  # None means ROOT location
 
            # create repositories with write permission on group is set to true
 
            create_on_write = HasPermissionAny('hg.create.write_on_repogroup.true')()
 
            group_admin = HasRepoGroupPermissionAny('group.admin')(gr_name,
 
                                            'can write into group validator')
 
            group_write = HasRepoGroupPermissionAny('group.write')(gr_name,
 
                                            'can write into group validator')
 
            forbidden = not (group_admin or (group_write and create_on_write))
 
            can_create_repos = HasPermissionAny('hg.admin', 'hg.create.repository')
 
            gid = (old_data['repo_group'].get('group_id')
 
                   if (old_data and 'repo_group' in old_data) else None)
 
            value_changed = gid != safe_int(value)
 
            new = not old_data
 
            # do check if we changed the value, there's a case that someone got
 
            # revoked write permissions to a repository, he still created, we
 
            # don't need to check permission if he didn't change the value of
 
            # groups in form box
 
            if value_changed or new:
 
                #parent group need to be existing
 
                if gr and forbidden:
 
                    msg = M(self, 'permission_denied', state)
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(repo_type=msg)
 
                    )
 
                ## check if we can write to root location !
 
                elif gr is None and not can_create_repos():
 
                    msg = M(self, 'permission_denied_root', state)
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(repo_type=msg)
 
                    )
 

	
 
    return _validator
 

	
 

	
 
def CanCreateGroup(can_create_in_root=False):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'permission_denied': _(u"You don't have permissions "
 
                                   "to create a group in this location")
 
        }
 

	
 
        def to_python(self, value, state):
 
            #root location
 
            if value in [-1, "-1"]:
 
                return None
 
            return value
 

	
 
        def validate_python(self, value, state):
 
            gr = RepoGroup.get(value)
 
            gr_name = gr.group_name if gr else None  # None means ROOT location
 

	
 
            if can_create_in_root and gr is None:
 
                #we can create in root, we're fine no validations required
 
                return
 

	
 
            forbidden_in_root = gr is None and not can_create_in_root
 
            val = HasRepoGroupPermissionAny('group.admin')
 
            forbidden = not val(gr_name, 'can create group validator')
 
            if forbidden_in_root or forbidden:
 
                msg = M(self, 'permission_denied', state)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(group_parent_id=msg)
 
                )
 

	
 
    return _validator
 

	
 

	
 
def ValidPerms(type_='repo'):
 
    if type_ == 'repo_group':
 
        EMPTY_PERM = 'group.none'
 
    elif type_ == 'repo':
 
        EMPTY_PERM = 'repository.none'
 
    elif type_ == 'user_group':
 
        EMPTY_PERM = 'usergroup.none'
 

	
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'perm_new_member_name':
 
                _(u'This username or user group name is not valid')
 
        }
 

	
 
        def to_python(self, value, state):
 
            perms_update = OrderedSet()
 
            perms_new = OrderedSet()
 
            # build a list of permission to update and new permission to create
 

	
 
            #CLEAN OUT ORG VALUE FROM NEW MEMBERS, and group them using
 
            new_perms_group = defaultdict(dict)
 
            for k, v in value.copy().iteritems():
 
                if k.startswith('perm_new_member'):
 
                    del value[k]
 
                    _type, part = k.split('perm_new_member_')
 
                    args = part.split('_')
 
                    if len(args) == 1:
 
                        new_perms_group[args[0]]['perm'] = v
 
                    elif len(args) == 2:
 
                        _key, pos = args
 
                        new_perms_group[pos][_key] = v
 

	
 
            # fill new permissions in order of how they were added
 
            for k in sorted(map(int, new_perms_group.keys())):
 
                perm_dict = new_perms_group[str(k)]
 
                new_member = perm_dict.get('name')
 
                new_perm = perm_dict.get('perm')
 
                new_type = perm_dict.get('type')
 
                if new_member and new_perm and new_type:
 
                    perms_new.add((new_member, new_perm, new_type))
 

	
 
            for k, v in value.iteritems():
 
                if k.startswith('u_perm_') or k.startswith('g_perm_'):
 
                    member = k[7:]
 
                    t = {'u': 'user',
 
                         'g': 'users_group'
 
                    }[k[0]]
 
                    if member == User.DEFAULT_USER:
 
                        if str2bool(value.get('repo_private')):
 
                            # set none for default when updating to
 
                            # private repo protects agains form manipulation
 
                            v = EMPTY_PERM
 
                    perms_update.add((member, v, t))
 

	
 
            value['perms_updates'] = list(perms_update)
 
            value['perms_new'] = list(perms_new)
 

	
 
            # update permissions
 
            for k, v, t in perms_new:
 
                try:
 
                    if t is 'user':
 
                        self.user_db = User.query()\
 
                            .filter(User.active == True)\
 
                            .filter(User.username == k).one()
 
                    if t is 'users_group':
 
                        self.user_db = UserGroup.query()\
 
                            .filter(UserGroup.users_group_active == True)\
 
                            .filter(UserGroup.users_group_name == k).one()
 

	
 
                except Exception:
 
                    log.exception('Updated permission failed')
 
                    msg = M(self, 'perm_new_member_type', state)
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(perm_new_member_name=msg)
 
                    )
 
            return value
 
    return _validator
 

	
 

	
 
def ValidSettings():
 
    class _validator(formencode.validators.FancyValidator):
 
        def _to_python(self, value, state):
 
            # settings  form for users that are not admin
 
            # can't edit certain parameters, it's extra backup if they mangle
 
            # with forms
 

	
 
            forbidden_params = [
 
                'user', 'repo_type', 'repo_enable_locking',
 
                'repo_enable_downloads', 'repo_enable_statistics'
 
            ]
 

	
 
            for param in forbidden_params:
 
                if param in value:
 
                    del value[param]
 
            return value
 

	
 
        def validate_python(self, value, state):
 
            pass
 
    return _validator
 

	
 

	
 
def ValidPath():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'invalid_path': _(u'This is not a valid path')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            if not os.path.isdir(value):
 
                msg = M(self, 'invalid_path', state)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(paths_root_path=msg)
 
                )
 
    return _validator
 

	
 

	
 
def UniqSystemEmail(old_data={}):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'email_taken': _(u'This e-mail address is already taken')
 
        }
 

	
 
        def _to_python(self, value, state):
 
            return value.lower()
 

	
 
        def validate_python(self, value, state):
 
            if (old_data.get('email') or '').lower() != value:
 
                user = User.get_by_email(value, case_insensitive=True)
 
                if user:
 
                    msg = M(self, 'email_taken', state)
 
                    raise formencode.Invalid(msg, value, state,
 
                        error_dict=dict(email=msg)
 
                    )
 
    return _validator
 

	
 

	
 
def ValidSystemEmail():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'non_existing_email': _(u'e-mail "%(email)s" does not exist.')
 
        }
 

	
 
        def _to_python(self, value, state):
 
            return value.lower()
 

	
 
        def validate_python(self, value, state):
 
            user = User.get_by_email(value, case_insensitive=True)
 
            if user is None:
 
                msg = M(self, 'non_existing_email', state, email=value)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(email=msg)
 
                )
 

	
 
    return _validator
 

	
 

	
 
def LdapLibValidator():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 

	
 
        }
 

	
 
        def validate_python(self, value, state):
 
            try:
 
                import ldap
 
                ldap  # pyflakes silence !
 
            except ImportError:
 
                raise LdapImportError()
 

	
 
    return _validator
 

	
 

	
 
def AttrLoginValidator():
 
    class _validator(formencode.validators.UnicodeString):
 
        messages = {
 
            'invalid_cn':
 
                  _(u'The LDAP Login attribute of the CN must be specified - '
 
                    'this is the name of the attribute that is equivalent '
 
                    'to "username"')
 
        }
 
        messages['empty'] = messages['invalid_cn']
 

	
 
    return _validator
 

	
 

	
 
def NotReviewedRevisions(repo_id):
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = {
 
            'rev_already_reviewed':
 
                  _(u'Revisions %(revs)s are already part of pull request '
 
                    'or have set status')
 
        }
 

	
 
        def validate_python(self, value, state):
 
            # check revisions if they are not reviewed, or a part of another
 
            # pull request
 
            statuses = ChangesetStatus.query()\
 
                .filter(ChangesetStatus.revision.in_(value))\
 
                .filter(ChangesetStatus.repo_id == repo_id)\
 
                .all()
 

	
 
            errors = []
 
            for cs in statuses:
 
                if cs.pull_request_id:
 
                    errors.append(['pull_req', cs.revision[:12]])
 
                elif cs.status:
 
                    errors.append(['status', cs.revision[:12]])
 

	
 
            if errors:
 
                revs = ','.join([x[1] for x in errors])
 
                msg = M(self, 'rev_already_reviewed', state, revs=revs)
 
                raise formencode.Invalid(msg, value, state,
 
                    error_dict=dict(revisions=revs)
 
                )
 

	
 
    return _validator
 

	
 

	
 
def ValidIp():
 
    class _validator(CIDR):
 
        messages = dict(
 
            badFormat=_('Please enter a valid IPv4 or IpV6 address'),
 
            illegalBits=_('The network size (bits) must be within the range'
 
                ' of 0-32 (not %(bits)r)')
 
        )
 

	
 
        def to_python(self, value, state):
 
            v = super(_validator, self).to_python(value, state)
 
            v = v.strip()
 
            net = ipaddr.IPNetwork(address=v)
 
            if isinstance(net, ipaddr.IPv4Network):
 
                #if IPv4 doesn't end with a mask, add /32
 
                if '/' not in value:
 
                    v += '/32'
 
            if isinstance(net, ipaddr.IPv6Network):
 
                #if IPv6 doesn't end with a mask, add /128
 
                if '/' not in value:
 
                    v += '/128'
 
            return v
 

	
 
        def validate_python(self, value, state):
 
            try:
 
                addr = value.strip()
 
                #this raises an ValueError if address is not IpV4 or IpV6
 
                ipaddr.IPNetwork(address=addr)
 
            except ValueError:
 
                raise formencode.Invalid(self.message('badFormat', state),
 
                                         value, state)
 

	
 
    return _validator
 

	
 

	
 
def FieldKey():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = dict(
 
            badFormat=_('Key name can only consist of letters, '
 
                        'underscore, dash or numbers')
 
        )
 

	
 
        def validate_python(self, value, state):
 
            if not re.match('[a-zA-Z0-9_-]+$', value):
 
                raise formencode.Invalid(self.message('badFormat', state),
 
                                         value, state)
 
    return _validator
 

	
 

	
 
def BasePath():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = dict(
 
            badPath=_('Filename cannot be inside a directory')
 
        )
 

	
 
        def _to_python(self, value, state):
 
            return value
 

	
 
        def validate_python(self, value, state):
 
            if value != os.path.basename(value):
 
                raise formencode.Invalid(self.message('badPath', state),
 
                                         value, state)
 
    return _validator
 

	
 

	
 
def ValidAuthPlugins():
 
    class _validator(formencode.validators.FancyValidator):
 
        messages = dict(
 
            import_duplicate=_('Plugins %(loaded)s and %(next_to_load)s both export the same name')
 
        )
 

	
 
        def _to_python(self, value, state):
 
            # filter empty values
 
            return filter(lambda s: s not in [None, ''], value)
 

	
 
        def validate_python(self, value, state):
 
            from kallithea.lib import auth_modules
 
            module_list = value
 
            unique_names = {}
 
            try:
 
                for module in module_list:
 
                    plugin = auth_modules.loadplugin(module)
 
                    plugin_name = plugin.name
 
                    if plugin_name in unique_names:
 
                        msg = M(self, 'import_duplicate', state,
 
                                loaded=unique_names[plugin_name],
 
                                next_to_load=plugin_name)
 
                        raise formencode.Invalid(msg, value, state)
 
                    unique_names[plugin_name] = plugin
 
            except (ImportError, AttributeError, TypeError), e:
 
                raise formencode.Invalid(str(e), value, state)
 

	
 
    return _validator
kallithea/tests/functional/test_admin_repos.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
import os
 
import mock
 
import urllib
 

	
 
from kallithea.lib import vcs
 
from kallithea.model.db import Repository, RepoGroup, UserRepoToPerm, User,\
 
    Permission
 
from kallithea.model.user import UserModel
 
from kallithea.tests import *
 
from kallithea.model.repo_group import RepoGroupModel
 
from kallithea.model.repo import RepoModel
 
from kallithea.model.meta import Session
 
from kallithea.tests.fixture import Fixture, error_function
 

	
 
fixture = Fixture()
 

	
 

	
 
def _get_permission_for_user(user, repo):
 
    perm = UserRepoToPerm.query()\
 
                .filter(UserRepoToPerm.repository ==
 
                        Repository.get_by_repo_name(repo))\
 
                .filter(UserRepoToPerm.user == User.get_by_username(user))\
 
                .all()
 
    return perm
 

	
 

	
 
class _BaseTest(TestController):
 
    """
 
    Write all tests here
 
    """
 
    REPO = None
 
    REPO_TYPE = None
 
    NEW_REPO = None
 
    OTHER_TYPE_REPO = None
 
    OTHER_TYPE = None
 

	
 
    @classmethod
 
    def setup_class(cls):
 
        pass
 

	
 
    @classmethod
 
    def teardown_class(cls):
 
        pass
 

	
 
    def test_index(self):
 
        self.log_user()
 
        response = self.app.get(url('repos'))
 

	
 
    def test_create(self):
 
        self.log_user()
 
        repo_name = self.NEW_REPO
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description))
 
        ## run the check page that triggers the flash message
 
        response = self.app.get(url('repo_check_home', repo_name=repo_name))
 
        self.assertEqual(response.json, {u'result': True})
 
        self.checkSessionFlash(response,
 
                               'Created repository <a href="/%s">%s</a>'
 
                               % (repo_name, repo_name))
 

	
 
        # test if the repo was created in the database
 
        new_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name).one()
 

	
 
        self.assertEqual(new_repo.repo_name, repo_name)
 
        self.assertEqual(new_repo.description, description)
 

	
 
        # test if the repository is visible in the list ?
 
        response = self.app.get(url('summary_home', repo_name=repo_name))
 
        response.mustcontain(repo_name)
 
        response.mustcontain(self.REPO_TYPE)
 

	
 
        # test if the repository was created on filesystem
 
        try:
 
            vcs.get_repo(os.path.join(TESTS_TMP_PATH, repo_name))
 
        except Exception:
 
        except vcs.exceptions.VCSError:
 
            self.fail('no repo %s in filesystem' % repo_name)
 

	
 
        RepoModel().delete(repo_name)
 
        Session().commit()
 

	
 
    def test_create_non_ascii(self):
 
        self.log_user()
 
        non_ascii = "ąęł"
 
        repo_name = "%s%s" % (self.NEW_REPO, non_ascii)
 
        repo_name_unicode = repo_name.decode('utf8')
 
        description = 'description for newly created repo' + non_ascii
 
        description_unicode = description.decode('utf8')
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description))
 
        ## run the check page that triggers the flash message
 
        response = self.app.get(url('repo_check_home', repo_name=repo_name))
 
        self.assertEqual(response.json, {u'result': True})
 
        self.checkSessionFlash(response,
 
                               u'Created repository <a href="/%s">%s</a>'
 
                               % (urllib.quote(repo_name), repo_name_unicode))
 
        # test if the repo was created in the database
 
        new_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name_unicode).one()
 

	
 
        self.assertEqual(new_repo.repo_name, repo_name_unicode)
 
        self.assertEqual(new_repo.description, description_unicode)
 

	
 
        # test if the repository is visible in the list ?
 
        response = self.app.get(url('summary_home', repo_name=repo_name))
 
        response.mustcontain(repo_name)
 
        response.mustcontain(self.REPO_TYPE)
 

	
 
        # test if the repository was created on filesystem
 
        try:
 
            vcs.get_repo(os.path.join(TESTS_TMP_PATH, repo_name))
 
        except Exception:
 
        except vcs.exceptions.VCSError:
 
            self.fail('no repo %s in filesystem' % repo_name)
 

	
 
    def test_create_in_group(self):
 
        self.log_user()
 

	
 
        ## create GROUP
 
        group_name = 'sometest_%s' % self.REPO_TYPE
 
        gr = RepoGroupModel().create(group_name=group_name,
 
                                     group_description='test',
 
                                     owner=TEST_USER_ADMIN_LOGIN)
 
        Session().commit()
 

	
 
        repo_name = 'ingroup'
 
        repo_name_full = RepoGroup.url_sep().join([group_name, repo_name])
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description,
 
                                                repo_group=gr.group_id,))
 
        ## run the check page that triggers the flash message
 
        response = self.app.get(url('repo_check_home', repo_name=repo_name_full))
 
        self.assertEqual(response.json, {u'result': True})
 
        self.checkSessionFlash(response,
 
                               'Created repository <a href="/%s">%s</a>'
 
                               % (repo_name_full, repo_name_full))
 
        # test if the repo was created in the database
 
        new_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name_full).one()
 
        new_repo_id = new_repo.repo_id
 

	
 
        self.assertEqual(new_repo.repo_name, repo_name_full)
 
        self.assertEqual(new_repo.description, description)
 

	
 
        # test if the repository is visible in the list ?
 
        response = self.app.get(url('summary_home', repo_name=repo_name_full))
 
        response.mustcontain(repo_name_full)
 
        response.mustcontain(self.REPO_TYPE)
 

	
 
        inherited_perms = UserRepoToPerm.query()\
 
            .filter(UserRepoToPerm.repository_id == new_repo_id).all()
 
        self.assertEqual(len(inherited_perms), 1)
 

	
 
        # test if the repository was created on filesystem
 
        try:
 
            vcs.get_repo(os.path.join(TESTS_TMP_PATH, repo_name_full))
 
        except Exception:
 
        except vcs.exceptions.VCSError:
 
            RepoGroupModel().delete(group_name)
 
            Session().commit()
 
            self.fail('no repo %s in filesystem' % repo_name)
 

	
 
        RepoModel().delete(repo_name_full)
 
        RepoGroupModel().delete(group_name)
 
        Session().commit()
 

	
 
    def test_create_in_group_without_needed_permissions(self):
 
        usr = self.log_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS)
 
        # revoke
 
        user_model = UserModel()
 
        # disable fork and create on default user
 
        user_model.revoke_perm(User.DEFAULT_USER, 'hg.create.repository')
 
        user_model.grant_perm(User.DEFAULT_USER, 'hg.create.none')
 
        user_model.revoke_perm(User.DEFAULT_USER, 'hg.fork.repository')
 
        user_model.grant_perm(User.DEFAULT_USER, 'hg.fork.none')
 

	
 
        # disable on regular user
 
        user_model.revoke_perm(TEST_USER_REGULAR_LOGIN, 'hg.create.repository')
 
        user_model.grant_perm(TEST_USER_REGULAR_LOGIN, 'hg.create.none')
 
        user_model.revoke_perm(TEST_USER_REGULAR_LOGIN, 'hg.fork.repository')
 
        user_model.grant_perm(TEST_USER_REGULAR_LOGIN, 'hg.fork.none')
 
        Session().commit()
 

	
 
        ## create GROUP
 
        group_name = 'reg_sometest_%s' % self.REPO_TYPE
 
        gr = RepoGroupModel().create(group_name=group_name,
 
                                     group_description='test',
 
                                     owner=TEST_USER_ADMIN_LOGIN)
 
        Session().commit()
 

	
 
        group_name_allowed = 'reg_sometest_allowed_%s' % self.REPO_TYPE
 
        gr_allowed = RepoGroupModel().create(group_name=group_name_allowed,
 
                                     group_description='test',
 
                                     owner=TEST_USER_REGULAR_LOGIN)
 
        Session().commit()
 

	
 
        repo_name = 'ingroup'
 
        repo_name_full = RepoGroup.url_sep().join([group_name, repo_name])
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description,
 
                                                repo_group=gr.group_id,))
 

	
 
        response.mustcontain('Invalid value')
 

	
 
        # user is allowed to create in this group
 
        repo_name = 'ingroup'
 
        repo_name_full = RepoGroup.url_sep().join([group_name_allowed, repo_name])
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description,
 
                                                repo_group=gr_allowed.group_id,))
 

	
 
        ## run the check page that triggers the flash message
 
        response = self.app.get(url('repo_check_home', repo_name=repo_name_full))
 
        self.assertEqual(response.json, {u'result': True})
 
        self.checkSessionFlash(response,
 
                               'Created repository <a href="/%s">%s</a>'
 
                               % (repo_name_full, repo_name_full))
 
        # test if the repo was created in the database
 
        new_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name_full).one()
 
        new_repo_id = new_repo.repo_id
 

	
 
        self.assertEqual(new_repo.repo_name, repo_name_full)
 
        self.assertEqual(new_repo.description, description)
 

	
 
        # test if the repository is visible in the list ?
 
        response = self.app.get(url('summary_home', repo_name=repo_name_full))
 
        response.mustcontain(repo_name_full)
 
        response.mustcontain(self.REPO_TYPE)
 

	
 
        inherited_perms = UserRepoToPerm.query()\
 
            .filter(UserRepoToPerm.repository_id == new_repo_id).all()
 
        self.assertEqual(len(inherited_perms), 1)
 

	
 
        # test if the repository was created on filesystem
 
        try:
 
            vcs.get_repo(os.path.join(TESTS_TMP_PATH, repo_name_full))
 
        except Exception:
 
        except vcs.exceptions.VCSError:
 
            RepoGroupModel().delete(group_name)
 
            Session().commit()
 
            self.fail('no repo %s in filesystem' % repo_name)
 

	
 
        RepoModel().delete(repo_name_full)
 
        RepoGroupModel().delete(group_name)
 
        RepoGroupModel().delete(group_name_allowed)
 
        Session().commit()
 

	
 
    def test_create_in_group_inherit_permissions(self):
 
        self.log_user()
 

	
 
        ## create GROUP
 
        group_name = 'sometest_%s' % self.REPO_TYPE
 
        gr = RepoGroupModel().create(group_name=group_name,
 
                                     group_description='test',
 
                                     owner=TEST_USER_ADMIN_LOGIN)
 
        perm = Permission.get_by_key('repository.write')
 
        RepoGroupModel().grant_user_permission(gr, TEST_USER_REGULAR_LOGIN, perm)
 

	
 
        ## add repo permissions
 
        Session().commit()
 

	
 
        repo_name = 'ingroup_inherited_%s' % self.REPO_TYPE
 
        repo_name_full = RepoGroup.url_sep().join([group_name, repo_name])
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description,
 
                                                repo_group=gr.group_id,
 
                                                repo_copy_permissions=True))
 

	
 
        ## run the check page that triggers the flash message
 
        response = self.app.get(url('repo_check_home', repo_name=repo_name_full))
 
        self.checkSessionFlash(response,
 
                               'Created repository <a href="/%s">%s</a>'
 
                               % (repo_name_full, repo_name_full))
 
        # test if the repo was created in the database
 
        new_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name_full).one()
 
        new_repo_id = new_repo.repo_id
 

	
 
        self.assertEqual(new_repo.repo_name, repo_name_full)
 
        self.assertEqual(new_repo.description, description)
 

	
 
        # test if the repository is visible in the list ?
 
        response = self.app.get(url('summary_home', repo_name=repo_name_full))
 
        response.mustcontain(repo_name_full)
 
        response.mustcontain(self.REPO_TYPE)
 

	
 
        # test if the repository was created on filesystem
 
        try:
 
            vcs.get_repo(os.path.join(TESTS_TMP_PATH, repo_name_full))
 
        except Exception:
 
        except vcs.exceptions.VCSError:
 
            RepoGroupModel().delete(group_name)
 
            Session().commit()
 
            self.fail('no repo %s in filesystem' % repo_name)
 

	
 
        #check if inherited permissiona are applied
 
        inherited_perms = UserRepoToPerm.query()\
 
            .filter(UserRepoToPerm.repository_id == new_repo_id).all()
 
        self.assertEqual(len(inherited_perms), 2)
 

	
 
        self.assertTrue(TEST_USER_REGULAR_LOGIN in [x.user.username
 
                                                    for x in inherited_perms])
 
        self.assertTrue('repository.write' in [x.permission.permission_name
 
                                               for x in inherited_perms])
 

	
 
        RepoModel().delete(repo_name_full)
 
        RepoGroupModel().delete(group_name)
 
        Session().commit()
 

	
 
    def test_create_remote_repo_wrong_clone_uri(self):
 
        self.log_user()
 
        repo_name = self.NEW_REPO
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description,
 
                                                clone_uri='http://127.0.0.1/repo'))
 
        response.mustcontain('invalid clone url')
 

	
 

	
 
    def test_create_remote_repo_wrong_clone_uri_hg_svn(self):
 
        self.log_user()
 
        repo_name = self.NEW_REPO
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description,
 
                                                clone_uri='svn+http://127.0.0.1/repo'))
 
        response.mustcontain('invalid clone url')
 

	
 

	
 
    def test_delete(self):
 
        self.log_user()
 
        repo_name = 'vcs_test_new_to_delete_%s' % self.REPO_TYPE
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_name=repo_name,
 
                                                repo_description=description))
 
        ## run the check page that triggers the flash message
 
        response = self.app.get(url('repo_check_home', repo_name=repo_name))
 
        self.checkSessionFlash(response,
 
                               'Created repository <a href="/%s">%s</a>'
 
                               % (repo_name, repo_name))
 
        # test if the repo was created in the database
 
        new_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name).one()
 

	
 
        self.assertEqual(new_repo.repo_name, repo_name)
 
        self.assertEqual(new_repo.description, description)
 

	
 
        # test if the repository is visible in the list ?
 
        response = self.app.get(url('summary_home', repo_name=repo_name))
 
        response.mustcontain(repo_name)
 
        response.mustcontain(self.REPO_TYPE)
 

	
 
        # test if the repository was created on filesystem
 
        try:
 
            vcs.get_repo(os.path.join(TESTS_TMP_PATH, repo_name))
 
        except Exception:
 
        except vcs.exceptions.VCSError:
 
            self.fail('no repo %s in filesystem' % repo_name)
 

	
 
        response = self.app.delete(url('repo', repo_name=repo_name))
 

	
 
        self.checkSessionFlash(response, 'Deleted repository %s' % (repo_name))
 

	
 
        response.follow()
 

	
 
        #check if repo was deleted from db
 
        deleted_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name).scalar()
 

	
 
        self.assertEqual(deleted_repo, None)
 

	
 
        self.assertEqual(os.path.isdir(os.path.join(TESTS_TMP_PATH, repo_name)),
 
                                  False)
 

	
 
    def test_delete_non_ascii(self):
 
        self.log_user()
 
        non_ascii = "ąęł"
 
        repo_name = "%s%s" % (self.NEW_REPO, non_ascii)
 
        repo_name_unicode = repo_name.decode('utf8')
 
        description = 'description for newly created repo' + non_ascii
 
        description_unicode = description.decode('utf8')
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description))
 
        ## run the check page that triggers the flash message
 
        response = self.app.get(url('repo_check_home', repo_name=repo_name))
 
        self.assertEqual(response.json, {u'result': True})
 
        self.checkSessionFlash(response,
 
                               u'Created repository <a href="/%s">%s</a>'
 
                               % (urllib.quote(repo_name), repo_name_unicode))
 
        # test if the repo was created in the database
 
        new_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name_unicode).one()
 

	
 
        self.assertEqual(new_repo.repo_name, repo_name_unicode)
 
        self.assertEqual(new_repo.description, description_unicode)
 

	
 
        # test if the repository is visible in the list ?
 
        response = self.app.get(url('summary_home', repo_name=repo_name))
 
        response.mustcontain(repo_name)
 
        response.mustcontain(self.REPO_TYPE)
 

	
 
        # test if the repository was created on filesystem
 
        try:
 
            vcs.get_repo(os.path.join(TESTS_TMP_PATH, repo_name))
 
        except Exception:
 
        except vcs.exceptions.VCSError:
 
            self.fail('no repo %s in filesystem' % repo_name)
 

	
 
        response = self.app.delete(url('repo', repo_name=repo_name))
 
        self.checkSessionFlash(response, 'Deleted repository %s' % (repo_name_unicode))
 
        response.follow()
 

	
 
        #check if repo was deleted from db
 
        deleted_repo = Session().query(Repository)\
 
            .filter(Repository.repo_name == repo_name_unicode).scalar()
 

	
 
        self.assertEqual(deleted_repo, None)
 

	
 
        self.assertEqual(os.path.isdir(os.path.join(TESTS_TMP_PATH, repo_name)),
 
                                  False)
 

	
 
    def test_delete_repo_with_group(self):
 
        #TODO:
 
        pass
 

	
 
    def test_delete_browser_fakeout(self):
 
        response = self.app.post(url('repo', repo_name=self.REPO),
 
                                 params=dict(_method='delete'))
 

	
 
    def test_show(self):
 
        self.log_user()
 
        response = self.app.get(url('repo', repo_name=self.REPO))
 

	
 
    def test_edit(self):
 
        response = self.app.get(url('edit_repo', repo_name=self.REPO))
 

	
 
    def test_set_private_flag_sets_default_to_none(self):
 
        self.log_user()
 
        #initially repository perm should be read
 
        perm = _get_permission_for_user(user='default', repo=self.REPO)
 
        self.assertTrue(len(perm), 1)
 
        self.assertEqual(perm[0].permission.permission_name, 'repository.read')
 
        self.assertEqual(Repository.get_by_repo_name(self.REPO).private, False)
 

	
 
        response = self.app.put(url('repo', repo_name=self.REPO),
 
                        fixture._get_repo_create_params(repo_private=1,
 
                                                repo_name=self.REPO,
 
                                                repo_type=self.REPO_TYPE,
 
                                                user=TEST_USER_ADMIN_LOGIN))
 
        self.checkSessionFlash(response,
 
                               msg='Repository %s updated successfully' % (self.REPO))
 
        self.assertEqual(Repository.get_by_repo_name(self.REPO).private, True)
 

	
 
        #now the repo default permission should be None
 
        perm = _get_permission_for_user(user='default', repo=self.REPO)
 
        self.assertTrue(len(perm), 1)
 
        self.assertEqual(perm[0].permission.permission_name, 'repository.none')
 

	
 
        response = self.app.put(url('repo', repo_name=self.REPO),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=self.REPO,
 
                                                repo_type=self.REPO_TYPE,
 
                                                user=TEST_USER_ADMIN_LOGIN))
 
        self.checkSessionFlash(response,
 
                               msg='Repository %s updated successfully' % (self.REPO))
 
        self.assertEqual(Repository.get_by_repo_name(self.REPO).private, False)
 

	
 
        #we turn off private now the repo default permission should stay None
 
        perm = _get_permission_for_user(user='default', repo=self.REPO)
 
        self.assertTrue(len(perm), 1)
 
        self.assertEqual(perm[0].permission.permission_name, 'repository.none')
 

	
 
        #update this permission back
 
        perm[0].permission = Permission.get_by_key('repository.read')
 
        Session().add(perm[0])
 
        Session().commit()
 

	
 
    def test_set_repo_fork_has_no_self_id(self):
 
        self.log_user()
 
        repo = Repository.get_by_repo_name(self.REPO)
 
        response = self.app.get(url('edit_repo_advanced', repo_name=self.REPO))
 
        opt = """<option value="%s">vcs_test_git</option>""" % repo.repo_id
 
        response.mustcontain(no=[opt])
 

	
 
    def test_set_fork_of_other_repo(self):
 
        self.log_user()
 
        other_repo = 'other_%s' % self.REPO_TYPE
 
        fixture.create_repo(other_repo, repo_type=self.REPO_TYPE)
 
        repo = Repository.get_by_repo_name(self.REPO)
 
        repo2 = Repository.get_by_repo_name(other_repo)
 
        response = self.app.put(url('edit_repo_advanced_fork', repo_name=self.REPO),
 
                                params=dict(id_fork_of=repo2.repo_id))
 
        repo = Repository.get_by_repo_name(self.REPO)
 
        repo2 = Repository.get_by_repo_name(other_repo)
 
        self.checkSessionFlash(response,
 
            'Marked repo %s as fork of %s' % (repo.repo_name, repo2.repo_name))
 

	
 
        assert repo.fork == repo2
 
        response = response.follow()
 
        # check if given repo is selected
 

	
 
        opt = """<option value="%s" selected="selected">%s</option>""" % (
 
                    repo2.repo_id, repo2.repo_name)
 
        response.mustcontain(opt)
 

	
 
        fixture.destroy_repo(other_repo, forks='detach')
 

	
 
    def test_set_fork_of_other_type_repo(self):
 
        self.log_user()
 
        repo = Repository.get_by_repo_name(self.REPO)
 
        repo2 = Repository.get_by_repo_name(self.OTHER_TYPE_REPO)
 
        response = self.app.put(url('edit_repo_advanced_fork', repo_name=self.REPO),
 
                                params=dict(id_fork_of=repo2.repo_id))
 
        repo = Repository.get_by_repo_name(self.REPO)
 
        repo2 = Repository.get_by_repo_name(self.OTHER_TYPE_REPO)
 
        self.checkSessionFlash(response,
 
            'Cannot set repository as fork of repository with other type')
 

	
 
    def test_set_fork_of_none(self):
 
        self.log_user()
 
        ## mark it as None
 
        response = self.app.put(url('edit_repo_advanced_fork', repo_name=self.REPO),
 
                                params=dict(id_fork_of=None))
 
        repo = Repository.get_by_repo_name(self.REPO)
 
        repo2 = Repository.get_by_repo_name(self.OTHER_TYPE_REPO)
 
        self.checkSessionFlash(response,
 
                               'Marked repo %s as fork of %s'
 
                               % (repo.repo_name, "Nothing"))
 
        assert repo.fork is None
 

	
 
    def test_set_fork_of_same_repo(self):
 
        self.log_user()
 
        repo = Repository.get_by_repo_name(self.REPO)
 
        response = self.app.put(url('edit_repo_advanced_fork', repo_name=self.REPO),
 
                                params=dict(id_fork_of=repo.repo_id))
 
        self.checkSessionFlash(response,
 
                               'An error occurred during this operation')
 

	
 
    def test_create_on_top_level_without_permissions(self):
 
        usr = self.log_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS)
 
        # revoke
 
        user_model = UserModel()
 
        # disable fork and create on default user
 
        user_model.revoke_perm(User.DEFAULT_USER, 'hg.create.repository')
 
        user_model.grant_perm(User.DEFAULT_USER, 'hg.create.none')
 
        user_model.revoke_perm(User.DEFAULT_USER, 'hg.fork.repository')
 
        user_model.grant_perm(User.DEFAULT_USER, 'hg.fork.none')
 

	
 
        # disable on regular user
 
        user_model.revoke_perm(TEST_USER_REGULAR_LOGIN, 'hg.create.repository')
 
        user_model.grant_perm(TEST_USER_REGULAR_LOGIN, 'hg.create.none')
 
        user_model.revoke_perm(TEST_USER_REGULAR_LOGIN, 'hg.fork.repository')
 
        user_model.grant_perm(TEST_USER_REGULAR_LOGIN, 'hg.fork.none')
 
        Session().commit()
 

	
 

	
 
        user = User.get(usr['user_id'])
 

	
 
        repo_name = self.NEW_REPO+'no_perms'
 
        description = 'description for newly created repo'
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description))
 

	
 
        response.mustcontain('no permission to create repository in root location')
 

	
 
        RepoModel().delete(repo_name)
 
        Session().commit()
 

	
 
    @mock.patch.object(RepoModel, '_create_filesystem_repo', error_function)
 
    def test_create_repo_when_filesystem_op_fails(self):
 
        self.log_user()
 
        repo_name = self.NEW_REPO
 
        description = 'description for newly created repo'
 

	
 
        response = self.app.post(url('repos'),
 
                        fixture._get_repo_create_params(repo_private=False,
 
                                                repo_name=repo_name,
 
                                                repo_type=self.REPO_TYPE,
 
                                                repo_description=description))
 

	
 
        self.checkSessionFlash(response,
 
                               'Error creating repository %s' % repo_name)
 
        # repo must not be in db
 
        repo = Repository.get_by_repo_name(repo_name)
 
        self.assertEqual(repo, None)
 

	
 
        # repo must not be in filesystem !
 
        self.assertFalse(os.path.isdir(os.path.join(TESTS_TMP_PATH, repo_name)))
 

	
 
class TestAdminReposControllerGIT(_BaseTest):
 
    REPO = GIT_REPO
 
    REPO_TYPE = 'git'
 
    NEW_REPO = NEW_GIT_REPO
 
    OTHER_TYPE_REPO = HG_REPO
 
    OTHER_TYPE = 'hg'
 

	
 

	
 
class TestAdminReposControllerHG(_BaseTest):
 
    REPO = HG_REPO
 
    REPO_TYPE = 'hg'
 
    NEW_REPO = NEW_HG_REPO
 
    OTHER_TYPE_REPO = GIT_REPO
 
    OTHER_TYPE = 'git'
kallithea/tests/scripts/test_concurency.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU General Public License for more details.
 
#
 
# You should have received a copy of the GNU General Public License
 
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
"""
 
kallithea.tests.test_hg_operations
 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

	
 
Test suite for making push/pull operations
 

	
 
This file was forked by the Kallithea project in July 2014.
 
Original author and date, and relevant copyright and licensing information is below:
 
:created_on: Dec 30, 2010
 
:author: marcink
 
:copyright: (c) 2013 RhodeCode GmbH, and others.
 
:license: GPLv3, see LICENSE.md for more details.
 

	
 
"""
 

	
 
import os
 
import sys
 
import shutil
 
import logging
 
from os.path import join as jn
 
from os.path import dirname as dn
 

	
 
from tempfile import _RandomNameSequence
 
from subprocess import Popen, PIPE
 

	
 
from paste.deploy import appconfig
 
from sqlalchemy import engine_from_config
 

	
 
from kallithea.lib.utils import add_cache
 
from kallithea.model import init_model
 
from kallithea.model import meta
 
from kallithea.model.db import User, Repository
 
from kallithea.lib.auth import get_crypt_password
 

	
 
from kallithea.tests import TESTS_TMP_PATH, HG_REPO
 
from kallithea.config.environment import load_environment
 

	
 
rel_path = dn(dn(dn(dn(os.path.abspath(__file__)))))
 
conf = appconfig('config:rc.ini', relative_to=rel_path)
 
load_environment(conf.global_conf, conf.local_conf)
 

	
 
add_cache(conf)
 

	
 
USER = 'test_admin'
 
PASS = 'test12'
 
HOST = 'rc.local'
 
METHOD = 'pull'
 
DEBUG = True
 
log = logging.getLogger(__name__)
 

	
 

	
 
class Command(object):
 

	
 
    def __init__(self, cwd):
 
        self.cwd = cwd
 

	
 
    def execute(self, cmd, *args):
 
        """Runs command on the system with given ``args``.
 
        """
 

	
 
        command = cmd + ' ' + ' '.join(args)
 
        log.debug('Executing %s' % command)
 
        if DEBUG:
 
            print command
 
        p = Popen(command, shell=True, stdout=PIPE, stderr=PIPE, cwd=self.cwd)
 
        stdout, stderr = p.communicate()
 
        if DEBUG:
 
            print stdout, stderr
 
        return stdout, stderr
 

	
 

	
 
def get_session():
 
    engine = engine_from_config(conf, 'sqlalchemy.db1.')
 
    init_model(engine)
 
    sa = meta.Session
 
    return sa
 

	
 

	
 
def create_test_user(force=True):
 
    print 'creating test user'
 
    sa = get_session()
 

	
 
    user = sa.query(User).filter(User.username == USER).scalar()
 

	
 
    if force and user is not None:
 
        print 'removing current user'
 
        for repo in sa.query(Repository).filter(Repository.user == user).all():
 
            sa.delete(repo)
 
        sa.delete(user)
 
        sa.commit()
 

	
 
    if user is None or force:
 
        print 'creating new one'
 
        new_usr = User()
 
        new_usr.username = USER
 
        new_usr.password = get_crypt_password(PASS)
 
        new_usr.email = 'mail@mail.com'
 
        new_usr.name = 'test'
 
        new_usr.lastname = 'lasttestname'
 
        new_usr.active = True
 
        new_usr.admin = True
 
        sa.add(new_usr)
 
        sa.commit()
 

	
 
    print 'done'
 

	
 

	
 
def create_test_repo(force=True):
 
    print 'creating test repo'
 
    from kallithea.model.repo import RepoModel
 
    sa = get_session()
 

	
 
    user = sa.query(User).filter(User.username == USER).scalar()
 
    if user is None:
 
        raise Exception('user not found')
 

	
 
    repo = sa.query(Repository).filter(Repository.repo_name == HG_REPO).scalar()
 

	
 
    if repo is None:
 
        print 'repo not found creating'
 

	
 
        form_data = {'repo_name': HG_REPO,
 
                     'repo_type': 'hg',
 
                     'private':False,
 
                     'clone_uri': '' }
 
        rm = RepoModel(sa)
 
        rm.base_path = '/home/hg'
 
        rm.create(form_data, user)
 

	
 
    print 'done'
 

	
 

	
 
def set_anonymous_access(enable=True):
 
    sa = get_session()
 
    user = sa.query(User).filter(User.username == 'default').one()
 
    user.active = enable
 
    sa.add(user)
 
    sa.commit()
 

	
 

	
 
def get_anonymous_access():
 
    sa = get_session()
 
    return sa.query(User).filter(User.username == 'default').one().active
 

	
 

	
 
#==============================================================================
 
# TESTS
 
#==============================================================================
 
def test_clone_with_credentials(no_errors=False, repo=HG_REPO, method=METHOD,
 
                                seq=None, backend='hg'):
 
    cwd = path = jn(TESTS_TMP_PATH, repo)
 

	
 
    if seq is None:
 
        seq = _RandomNameSequence().next()
 

	
 
    try:
 
        shutil.rmtree(path, ignore_errors=True)
 
        os.makedirs(path)
 
        #print 'made dirs %s' % jn(path)
 
    except OSError:
 
        raise
 

	
 
    clone_url = 'http://%(user)s:%(pass)s@%(host)s/%(cloned_repo)s' % \
 
                  {'user': USER,
 
                   'pass': PASS,
 
                   'host': HOST,
 
                   'cloned_repo': repo, }
 

	
 
    dest = path + seq
 
    if method == 'pull':
 
        stdout, stderr = Command(cwd).execute(backend, method, '--cwd', dest, clone_url)
 
    else:
 
        stdout, stderr = Command(cwd).execute(backend, method, clone_url, dest)
 
        print stdout,'sdasdsadsa'
 
        if not no_errors:
 
            if backend == 'hg':
 
                assert """adding file changes""" in stdout, 'no messages about cloning'
 
                assert """abort""" not in stderr , 'got error from clone'
 
            elif backend == 'git':
 
                assert """Cloning into""" in stdout, 'no messages about cloning'
 

	
 
if __name__ == '__main__':
 
    try:
 
        create_test_user(force=False)
 
        seq = None
 
        import time
 

	
 
        try:
 
            METHOD = sys.argv[3]
 
        except Exception:
 
        except IndexError:
 
            pass
 

	
 
        try:
 
            backend = sys.argv[4]
 
        except Exception:
 
        except IndexError:
 
            backend = 'hg'
 

	
 
        if METHOD == 'pull':
 
            seq = _RandomNameSequence().next()
 
            test_clone_with_credentials(repo=sys.argv[1], method='clone',
 
                                        seq=seq, backend=backend)
 
        s = time.time()
 
        for i in range(1, int(sys.argv[2]) + 1):
 
            print 'take', i
 
            test_clone_with_credentials(repo=sys.argv[1], method=METHOD,
 
                                        seq=seq, backend=backend)
 
        print 'time taken %.3f' % (time.time() - s)
 
    except Exception, e:
 
        sys.exit('stop on %s' % e)
0 comments (0 inline, 0 general)