
import hashlib
import md5
import struct
import sys

from twisted.internet import protocol, reactor

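# All regular messages begin with a one-byte ASCII type code, followed by a
# big-endian 32-bit length that counts itself and the contents but not the
# type byte. The exceptions are CancelRequest, SSLRequest and StartupMessage,
# which have no type byte and are sent only at the beginning of a connection.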

class PostgresProtocol(protocol.Protocol):
    """Message framing shared by the PostgreSQL client and server below.

    A regular message is a one-byte ASCII type code, a big-endian unsigned
    32-bit length that counts itself and the contents but not the type byte,
    and then the contents.  Received bytes accumulate in ``self.buffer``, and
    every complete message is passed to ``handle_<type>``.  When the subclass
    defines no such method, ``handle_unknown`` is called instead.
    """

    # Protocol 3.0: major version in the high 16 bits, minor in the low 16.
    protocol_version = 3 << 16

    # Class-level default so framing also works for subclasses that do not
    # initialise the buffer themselves; ``+=`` in dataReceived then creates
    # the per-instance attribute (strings are immutable, so this is safe).
    buffer = ''

    def dataReceived(self, data):
        """Append *data* to the buffer and dispatch any complete messages."""
        self.buffer += data
        self.dispatch()

    def dispatchOne(self, type, contents):
        """Route one message to ``handle_<type>``, else ``handle_unknown``."""
        handler = getattr(self, 'handle_' + type, None)

        if handler is None:
            self.handle_unknown(type, contents)
        else:
            handler(contents)

    def dispatch(self):
        """Dispatch every complete message currently in ``self.buffer``.

        A trailing partial message is left in the buffer until more data
        arrives.

        Raises:
            ValueError: if a header declares a length smaller than 4 (the
                length field alone), which would otherwise desynchronise the
                framing of everything that follows.
        """
        while len(self.buffer) >= 5:
            type, length = struct.unpack('>cI', self.buffer[:5])

            if length < 4:
                raise ValueError(
                    'invalid length %d for message type %r' % (length, type))

            # The length field does not count the type byte, hence the + 1.
            end = length + 1
            if len(self.buffer) < end:
                return

            contents = self.buffer[5:end]
            # Consume the message *before* handling it so that a handler
            # which raises (e.g. handle_E) does not leave the same message
            # in the buffer to be dispatched again on the next read.
            self.buffer = self.buffer[end:]
            self.dispatchOne(type, contents)

    def packMessage(self, type, contents):
        """Return the wire encoding of a *type* message carrying *contents*."""
        # The length field counts its own 4 bytes plus the contents.
        return struct.pack('>cI', type, len(contents) + 4) + contents

    def sendMessage(self, type, contents):
        """Frame *contents* as a *type* message and write it to the transport."""
        self.transport.write(self.packMessage(type, contents))

    def handle_unknown(self, type, contents):
        """Fallback for message types without a handler: log and ignore."""
        print('ignoring message: %r' % ((type, contents),))

class PostgresServer(PostgresProtocol):
    """Toy PostgreSQL backend that trusts every client.

    The connection starts in startup mode, with ``startup`` standing in for
    ``dispatch``.  In that mode SSL is declined and the StartupMessage
    parameters are stored in ``self.parameters``.  AuthenticationOk plus
    ReadyForQuery are then sent without checking any credentials.  After
    that, typed messages go through the inherited
    ``PostgresProtocol.dispatch``.
    """

    # Codes sent in the protocol-version slot of the untyped packets that
    # may precede (or replace) a StartupMessage.
    SSL_REQUEST_CODE = (1234 << 16) + 5679
    CANCEL_REQUEST_CODE = (1234 << 16) + 5678

    def __init__(self):
        self.buffer = ''
        # Until the StartupMessage has been handled, incoming bytes are parsed
        # as untyped startup packets rather than typed messages.
        self.dispatch = self.startup
        # Name -> value parameters from the client's StartupMessage.
        self.parameters = {}

    def connectionMade(self):
        print('connection made')

    def connectionLost(self, reason):
        print('connection lost: %r' % reason)

    def sendAuthenticationOk(self):
        """Send AuthenticationOk: an 'R' message with auth code 0."""
        self.sendMessage('R', struct.pack('>I', 0))

    def sendNotificationResponse(self, condition, payload='', pid=1234):
        """Send a NotificationResponse ('A') for channel *condition*.

        *pid* is reported as the notifying backend's process id and
        *payload* as the notification payload.  The defaults reproduce the
        fixed pid 1234 and empty payload this method always used to send.
        """
        self.sendMessage(
            'A', struct.pack('>I', pid) + condition + '\0' + payload + '\0')

    def sendReadyForQuery(self, status='I'):
        """Send ReadyForQuery ('Z') with transaction *status*.

        The default 'I' is the protocol's "idle, not in a transaction" status.
        """
        self.sendMessage('Z', status)

    def startup(self):
        """Consume untyped startup-phase packets from ``self.buffer``.

        - SSLRequest is answered with 'N' (SSL declined) and parsing continues.
        - CancelRequest closes the connection.
        - StartupMessage: its parameters are stored in ``self.parameters`` and
          it is answered with AuthenticationOk and ReadyForQuery.  The
          connection then switches to regular typed-message dispatch, and any
          bytes already buffered behind it are dispatched immediately.

        Returns early, consuming nothing, while a packet is still incomplete.

        Raises:
            ValueError: on a packet shorter than its 8-byte header, on an
                unsupported protocol version, or on malformed parameters.
        """
        while True:
            if len(self.buffer) < 8:
                return

            length, protocol_version = struct.unpack('>II', self.buffer[:8])
            if length < 8:
                raise ValueError('invalid startup packet length %d' % length)
            if len(self.buffer) < length:
                return

            # Slice off exactly this packet; anything after it belongs to the
            # next packet or message and must stay in the buffer.
            packet = self.buffer[8:length]
            self.buffer = self.buffer[length:]

            if protocol_version == self.SSL_REQUEST_CODE:
                # Reject the SSL request; the client may then continue
                # unencrypted with a StartupMessage.
                self.transport.write('N')
                continue

            if protocol_version == self.CANCEL_REQUEST_CODE:
                # Nothing to cancel here; a cancel connection gets no reply.
                self.transport.loseConnection()
                return

            if protocol_version != self.protocol_version:
                raise ValueError(
                    'unsupported protocol version %#x' % protocol_version)
            break

        # Parameters are NUL-terminated name/value pairs followed by one
        # extra NUL that ends the list.
        params = packet[:-1]
        while params:
            name, value, params = params.split('\0', 2)
            self.parameters[name] = value

        self.sendAuthenticationOk()
        self.sendReadyForQuery()

        # Remove the instance-level override set in __init__ so that
        # self.dispatch resolves to the inherited typed-message dispatcher,
        # then handle anything the client pipelined behind the startup packet.
        del self.dispatch
        self.dispatch()

    def handle_Q(self, contents):
        """Simple Query: ignore the SQL text, report ready, then NOTIFY 'foo'."""
        self.sendReadyForQuery()
        self.sendNotificationResponse('foo')

class PostgresClient(PostgresProtocol):
    """Minimal PostgreSQL frontend.

    On connect it sends a StartupMessage for ``username`` / ``database``.  It
    answers cleartext and MD5 password challenges with ``password`` and calls
    the ``authenticated()`` hook once the server reports AuthenticationOk.
    ErrorResponse messages are raised as RuntimeError.
    """

    def __init__(self, username, password=None, database=None):
        self.buffer = ''
        self.username = username
        self.password = password
        self.database = database

    def connectionMade(self):
        """Send the StartupMessage.

        The StartupMessage is untyped, so it is not framed via sendMessage.
        """
        print('connection made')
        # Keep the historical 'postgres' default when no database is given.
        if self.database is None:
            database = 'postgres'
        else:
            database = self.database
        contents = (
            struct.pack('>I', self.protocol_version) +
            'user\0' + self.username + '\0' +
            'database\0' + database + '\0' +
            '\0')  # an empty name terminates the parameter list
        # The length prefix counts its own 4 bytes as well as the contents.
        self.transport.write(struct.pack('>I', len(contents) + 4) + contents)

    def connectionLost(self, reason):
        print('connection lost')

    def handle_E(self, msg):
        """ErrorResponse: raise RuntimeError carrying the parsed fields.

        The body is a sequence of (one-byte field code, NUL-terminated
        string) pairs ended by a single NUL.  The fields are collected into a
        dict keyed by field code.
        """
        fields = {}
        body = msg[:-1]  # drop the NUL that ends the field list

        while body:
            code = body[0]
            value, body = body[1:].split('\0', 1)
            fields[code] = value

        raise RuntimeError(fields)

    def handle_K(self, msg):
        """BackendKeyData: currently ignored."""
        pass

    def handle_R(self, msg):
        """Authentication request from the server.

        Handled codes:
        - 0, AuthenticationOk: calls ``authenticated()``.
        - 3, AuthenticationCleartextPassword: sends the password as-is.
        - 5, AuthenticationMD5Password: sends
          'md5' + md5hex(md5hex(password + username) + salt).

        Raises:
            RuntimeError: for any other authentication code, or when a
                password is requested but none was supplied.
        """
        auth_type, = struct.unpack('>I', msg[:4])

        if auth_type == 0:
            self.authenticated()
            return

        if auth_type not in (3, 5):
            raise RuntimeError(
                'unknown authentication message type %d' % auth_type)

        if self.password is None:
            raise RuntimeError(
                'server requested a password but none was supplied')

        if auth_type == 3:
            password = self.password
        else:
            # The 4-byte salt follows the auth code.
            salt, = struct.unpack('4s', msg[4:])
            inner = hashlib.md5(self.password + self.username).hexdigest()
            password = 'md5' + hashlib.md5(inner + salt).hexdigest()

        # PasswordMessage carries a NUL-terminated string; the server rejects
        # the packet if the terminator is missing.
        self.sendMessage('p', password + '\0')

    def handle_S(self, msg):
        """ParameterStatus: print the reported (name, value) pair."""
        name, value = msg[:-1].split('\0', 1)
        print('S %r' % ((name, value),))

    def handle_Z(self, msg):
        """ReadyForQuery: nothing to do yet."""
        pass

    def authenticated(self):
        """Hook called on AuthenticationOk; override in subclasses."""
        pass

class PostgresServerFactory(protocol.ServerFactory):
    """Twisted server factory whose protocol class is PostgresServer."""

    protocol = PostgresServer

class PostgresClientFactory(protocol.ClientFactory):
    """Twisted client factory that builds PostgresClient instances.

    Every positional argument given here is forwarded unchanged to
    ``PostgresClient`` each time a protocol instance is built.
    """

    def __init__(self, *params):
        def build_client():
            return PostgresClient(*params)

        # Twisted calls ``self.protocol()`` to create each protocol instance.
        self.protocol = build_client

def main(argv):
    """Run the demo client and server until the reactor stops.

    The client (PostgresClient) connects to localhost:5432.  The server
    (PostgresServer) listens on port 5433.

    *argv* is a sys.argv-style sequence: [program, user, password, database].

    Raises:
        SystemExit: with a usage message when *argv* has the wrong length.
    """
    # Explicit unpacking replaces Python-2-only tuple parameters (PEP 3113)
    # and lets a bad command line produce a usage message.
    if len(argv) != 4:
        program = argv[0] if argv else 'main'
        raise SystemExit('usage: %s USER PASSWORD DATABASE' % program)

    _, user, password, database = argv
    reactor.connectTCP('localhost', 5432,
        PostgresClientFactory(user, password, database))
    reactor.listenTCP(5433, PostgresServerFactory())
    reactor.run()

if __name__ == '__main__':
    # Script entry point: hand the full command line (program name
    # included) to main(), which unpacks it.
    command_line = sys.argv
    main(command_line)

