import unittest
import shutil
import tempfile
import os.path
import sqlite3
import createrepo_c as cr

from .fixtures import *

class TestCaseSqlite(unittest.TestCase):
    """Tests for the createrepo_c sqlite database wrappers.

    Each test works in its own temporary directory. Databases written by
    createrepo_c are inspected through stdlib ``sqlite3`` connections opened
    with ``_connect``; those connections are closed in ``tearDown`` before the
    directory is removed, so no file is deleted while still held open.
    """

    def setUp(self):
        # Stdlib sqlite3 connections opened via _connect(); closed in tearDown.
        self._connections = []
        self.tmpdir = tempfile.mkdtemp(prefix="createrepo_ctest-")

    def tearDown(self):
        # Release every inspection connection before removing the files
        # (deleting open files fails on some platforms).
        for con in self._connections:
            con.close()
        shutil.rmtree(self.tmpdir)

    def _connect(self, path):
        """Open a stdlib sqlite3 connection to *path*; it is closed in tearDown."""
        con = sqlite3.connect(path)
        self._connections.append(con)
        return con

    @staticmethod
    def _schema_names(con, obj_type):
        """Return ``[(name,), ...]`` for schema objects of type *obj_type*.

        The type is bound as a parameter rather than written as ``"table"``:
        double quotes denote identifiers in SQL and only act as a string
        literal through SQLite's legacy fallback, which builds may disable.
        """
        return con.execute(
            "select name from sqlite_master where type = ?;", (obj_type,)
        ).fetchall()

    def test_sqlite_basic_operations(self):
        # Each database kind can be created both via the generic Sqlite class
        # and via its dedicated subclass; creation must produce the file.
        cases = [
            ("primary.db", lambda p: cr.Sqlite(p, cr.DB_PRIMARY)),
            ("primary2.db", lambda p: cr.PrimarySqlite(p)),
            ("filelists.db", lambda p: cr.Sqlite(p, cr.DB_FILELISTS)),
            ("filelists2.db", lambda p: cr.FilelistsSqlite(p)),
            ("other.db", lambda p: cr.Sqlite(p, cr.DB_OTHER)),
            ("other2.db", lambda p: cr.OtherSqlite(p)),
        ]
        for name, factory in cases:
            path = os.path.join(self.tmpdir, name)
            db = factory(path)
            self.assertTrue(db, msg=name)
            self.assertTrue(os.path.isfile(path), msg=name)
            db.close()

    def test_sqlite_error_cases(self):
        self.assertRaises(cr.CreaterepoCError, cr.Sqlite, self.tmpdir, cr.DB_PRIMARY)
        self.assertRaises(ValueError, cr.Sqlite, self.tmpdir+"/foo.db", 55)
        self.assertRaises(TypeError, cr.Sqlite, self.tmpdir+"/foo.db", None)
        self.assertRaises(TypeError, cr.Sqlite, None, cr.DB_PRIMARY)

    def test_sqlite_operations_on_closed_db(self):
        pkg = cr.package_from_rpm(PKG_ARCHER_PATH)
        path = os.path.join(self.tmpdir, "primary.db")
        db = cr.Sqlite(path, cr.DB_PRIMARY)
        self.assertTrue(db)
        db.close()

        self.assertRaises(cr.CreaterepoCError, db.add_pkg, pkg)
        self.assertRaises(cr.CreaterepoCError, db.dbinfo_update, "somechecksum")

        db.close()  # No error should be raised
        del db      # No error should be raised

    def test_sqlite_primary_schema(self):
        path = os.path.join(self.tmpdir, "primary.db")
        cr.PrimarySqlite(path).close()
        self.assertTrue(os.path.isfile(path))

        con = self._connect(path)
        # Check tables
        self.assertEqual(self._schema_names(con, "table"),
            [('db_info',),
             ('packages',),
             ('files',),
             ('requires',),
             ('provides',),
             ('conflicts',),
             ('obsoletes',),
             ('suggests',),
             ('enhances',),
             ('recommends',),
             ('supplements',),
            ])
        # Check indexes
        self.assertEqual(self._schema_names(con, "index"),
            [('packagename',),
             ('packageId',),
             ('filenames',),
             ('pkgfiles',),
             ('pkgrequires',),
             ('requiresname',),
             ('pkgprovides',),
             ('providesname',),
             ('pkgconflicts',),
             ('pkgobsoletes',),
             ('pkgsuggests',),
             ('pkgenhances',),
             ('pkgrecommends',),
             ('pkgsupplements',),
            ])
        # Check triggers
        self.assertEqual(self._schema_names(con, "trigger"),
            [('removals',)])

    def test_sqlite_filelists_schema(self):
        path = os.path.join(self.tmpdir, "filelists.db")
        cr.FilelistsSqlite(path).close()
        self.assertTrue(os.path.isfile(path))

        con = self._connect(path)
        # Check tables
        self.assertEqual(self._schema_names(con, "table"),
            [('db_info',), ('packages',), ('filelist',)])
        # Check indexes
        self.assertEqual(self._schema_names(con, "index"),
            [('keyfile',), ('pkgId',), ('dirnames',)])
        # Check triggers
        self.assertEqual(self._schema_names(con, "trigger"),
            [('remove_filelist',)])

    def test_sqlite_other_schema(self):
        path = os.path.join(self.tmpdir, "other.db")
        cr.OtherSqlite(path).close()
        self.assertTrue(os.path.isfile(path))

        con = self._connect(path)
        # Check tables
        self.assertEqual(self._schema_names(con, "table"),
            [('db_info',), ('packages',), ('changelog',)])
        # Check indexes
        self.assertEqual(self._schema_names(con, "index"),
            [('keychange',), ('pkgId',)])
        # Check triggers
        self.assertEqual(self._schema_names(con, "trigger"),
            [('remove_changelogs',)])

    def test_sqlite_primary(self):
        path = os.path.join(self.tmpdir, "primary.db")
        db = cr.Sqlite(path, cr.DB_PRIMARY)
        pkg = cr.package_from_rpm(PKG_ARCHER_PATH)
        db.add_pkg(pkg)
        self.assertRaises(TypeError, db.add_pkg, None)
        self.assertRaises(TypeError, db.add_pkg, 123)
        self.assertRaises(TypeError, db.add_pkg, "foo")
        db.dbinfo_update("somechecksum")
        self.assertRaises(TypeError, db.dbinfo_update, pkg)
        self.assertRaises(TypeError, db.dbinfo_update, None)
        self.assertRaises(TypeError, db.dbinfo_update, 123)
        db.close()

        self.assertTrue(os.path.isfile(path))

        con = self._connect(path)

        # Check packages table. Column 10 is compared against itself because
        # its value is not fixed by the package fixture.
        res = con.execute("select * from packages").fetchall()
        self.assertEqual(res,
            [(1, '4e0b775220c67f0f2c1fd2177e626b9c863a098130224ff09778ede25cea9a9e',
              'Archer', 'x86_64', '3.4.5', '2', '6', 'Complex package.',
              'Archer package', 'http://soo_complex_package.eu/',
              res[0][10], 1365416480, 'GPL', 'ISIS', 'Development/Tools',
              'localhost.localdomain', 'Archer-3.4.5-6.src.rpm', 280, 2865,
              'Sterling Archer', 3101, 0, 544, None, None, 'sha256')])

        # Check provides table
        self.assertEqual(con.execute("select * from provides").fetchall(),
            [('bara', 'LE', '0', '22', None, 1),
             ('barb', 'GE', '0', '11.22.33', '44', 1),
             ('barc', 'EQ', '0', '33', None, 1),
             ('bard', 'LT', '0', '44', None, 1),
             ('bare', 'GT', '0', '55', None, 1),
             ('Archer', 'EQ', '2', '3.4.5', '6', 1),
             ('Archer(x86-64)', 'EQ', '2', '3.4.5', '6', 1)])

        # Check conflicts table
        self.assertEqual(con.execute("select * from conflicts").fetchall(),
            [('bba', 'LE', '0', '2222', None, 1),
             ('bbb', 'GE', '0', '1111.2222.3333', '4444', 1),
             ('bbc', 'EQ', '0', '3333', None, 1),
             ('bbd', 'LT', '0', '4444', None, 1),
             ('bbe', 'GT', '0', '5555', None, 1)])

        # Check obsoletes table
        self.assertEqual(con.execute("select * from obsoletes").fetchall(),
           [('aaa', 'LE', '0', '222', None, 1),
            ('aab', 'GE', '0', '111.2.3', '4', 1),
            ('aac', 'EQ', '0', '333', None, 1),
            ('aad', 'LT', '0', '444', None, 1),
            ('aae', 'GT', '0', '555', None, 1)])

        # Check requires table
        self.assertEqual(con.execute("select * from requires").fetchall(),
            [('fooa', 'LE', '0', '2', None, 1, 'FALSE'),
             ('foob', 'GE', '0', '1.0.0', '1', 1, 'FALSE'),
             ('fooc', 'EQ', '0', '3', None, 1, 'FALSE'),
             ('food', 'LT', '0', '4', None, 1, 'FALSE'),
             ('fooe', 'GT', '0', '5', None, 1, 'FALSE'),
             ('foof', 'EQ', '0', '6', None, 1, 'TRUE')])

        # Check files table
        self.assertEqual(con.execute("select * from files").fetchall(),
            [('/usr/bin/complex_a', 'file', 1)])

        # Check db_info table
        self.assertEqual(con.execute("select * from db_info").fetchall(),
            [(10, 'somechecksum')])

    def test_sqlite_filelists(self):
        path = os.path.join(self.tmpdir, "filelists.db")
        db = cr.Sqlite(path, cr.DB_FILELISTS)
        pkg = cr.package_from_rpm(PKG_ARCHER_PATH)
        db.add_pkg(pkg)
        db.dbinfo_update("somechecksum2")
        db.close()

        self.assertTrue(os.path.isfile(path))

        con = self._connect(path)

        # Check packages table
        self.assertEqual(con.execute("select * from packages").fetchall(),
            [(1, '4e0b775220c67f0f2c1fd2177e626b9c863a098130224ff09778ede25cea9a9e')])

        # Check files table (row order is not significant here)
        self.assertEqual(set(con.execute("select * from filelist").fetchall()),
            set([(1, '/usr/share/doc', 'Archer-3.4.5', 'd'),
             (1, '/usr/bin', 'complex_a', 'f'),
             (1, '/usr/share/doc/Archer-3.4.5', 'README', 'f')]))

        # Check db_info table
        self.assertEqual(con.execute("select * from db_info").fetchall(),
            [(10, 'somechecksum2')])

    def test_sqlite_other(self):
        path = os.path.join(self.tmpdir, "other.db")
        # This test targets the "other" database, so it must use DB_OTHER.
        db = cr.Sqlite(path, cr.DB_OTHER)
        pkg = cr.package_from_rpm(PKG_ARCHER_PATH)
        db.add_pkg(pkg)
        db.dbinfo_update("somechecksum3")
        db.close()

        self.assertTrue(os.path.isfile(path))

        con = self._connect(path)

        # Check packages table
        self.assertEqual(con.execute("select * from packages").fetchall(),
            [(1, '4e0b775220c67f0f2c1fd2177e626b9c863a098130224ff09778ede25cea9a9e')])

        # Check changelog table: one row per changelog entry of the package.
        # Assumes the yum-compatible column order
        # (pkgKey, author, date, changelog) -- confirm against sqlite.c.
        expected = [(1, author, date, text)
                    for author, date, text in pkg.changelogs]
        self.assertEqual(
            sorted(con.execute("select * from changelog").fetchall()),
            sorted(expected))

        # Check db_info table
        self.assertEqual(con.execute("select * from db_info").fetchall(),
            [(10, 'somechecksum3')])
