//-----------------------------------------------------------------------------
//
// SipUdpPort.cpp - Handles receiving and queuing
//                  incoming messages and dequeueing and
//                  sending outgoing messages over UDP.
//
//    Copyright (C) 2004  Mark D. Collier
//
//    This program is free software; you can redistribute it and/or modify
//    it under the terms of the GNU General Public License as published by
//    the Free Software Foundation; either version 2 of the License, or
//    (at your option) any later version.
//
//    This program is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU General Public License for more details.
//
//    You should have received a copy of the GNU General Public License
//    along with this program; if not, write to the Free Software
//    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
//   Author: Mark D. Collier   - 12/01/2006   v1.1
//                   Mark D. Collier   -  04/26/2004  v1.0
//         www.securelogix.com - mark.collier@securelogix.com
//         www.hackingexposedvoip.com

//-----------------------------------------------------------------------------

#include <arpa/inet.h>
#include <errno.h>
#include <netinet/in.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>

#include "util.h"
#include "SipMessage.h"
#include "SipUdpPort.h"


// Largest datagram payload Run() reads in one recvfrom(); the receive buffer
// is allocated one byte larger so the payload can always be NUL-terminated.
#define MAX_PAYLOAD_LENGTH 65535
// When defined, messageQueue::Add() discards the oldest queued messages so
// that at most this many remain. Disabled (commented out) by default.
// #define MAX_MESSAGE_QUEUE_SIZE 3000

// Route every allocation call in the rest of this file through the my*
// wrappers, which are declared elsewhere (presumably util.h -- confirm).
// Memory obtained here must therefore be released with the remapped free().
#define malloc  mymalloc
#define free    myfree
#define strdup  mystrdup
#define strndup mystrndup


// Creates an idle port. Run() will bind it to INADDR_ANY:5060 unless
// SetBoundAddress()/SetBoundPort() are called first. Sending starts out
// enabled (not on hold) and every traffic counter starts at zero.
SipUdpPort::SipUdpPort( void )
{
    // Endpoint, kept in network byte order because Run() copies it
    // straight into a sockaddr_in.
    mAddress = INADDR_ANY;
    mPort    = htons( 5060 );

    // Run-state flags.
    mStopFlag  = false;
    mHoldState = false;

    mCreationTime = time( NULL );

    // Traffic statistics.
    mMessagesSent     = 0;
    mMessagesReceived = 0;
    mBytesSent        = 0;
    mBytesReceived    = 0;
}


// Nothing to release explicitly. The socket and receive buffer are locals of
// Run(), not members, and each queue member empties itself in its own
// destructor.
SipUdpPort::~SipUdpPort( void )
{
    // intentionally empty
}


namespace
{
    // Owns the receive buffer and the socket descriptor used by
    // SipUdpPort::Run(), so both are released on every return path,
    // including the early error returns.
    class udpRunResources
    {
    public:
        char *  Payload;
        int     SocketFd;

        udpRunResources( void )
            : Payload( NULL ),
              SocketFd( -1 )
        {
        }

        ~udpRunResources( void )
        {
            CloseSocket();
            if ( Payload )
            {
                free( Payload );
            }
        }

        // Closes the socket if it is open and marks it as closed.
        void  CloseSocket( void )
        {
            if ( SocketFd != -1 )
            {
                close( SocketFd );
                SocketFd = -1;
            }
        }

    private:
        // Not copyable: a copy would close/free the same resources twice.
        udpRunResources( const udpRunResources & );
        udpRunResources &  operator=( const udpRunResources & );
    };


    // Creates a UDP socket, requests aReceiveBufferSize bytes of kernel
    // receive buffer (SO_RCVBUF) and binds the socket to *aLocal.
    // Returns the descriptor on success. On failure returns -1, closes any
    // descriptor it created, and sets *aBindFailed to true only when bind()
    // was the failing call (false for socket()/setsockopt() failures).
    int  openBoundUdpSocket( const struct sockaddr_in *  aLocal,
                             int                         aReceiveBufferSize,
                             bool *                      aBindFailed )
    {
        int  socketFd;

        *aBindFailed = false;

        socketFd = socket( PF_INET, SOCK_DGRAM, IPPROTO_UDP );
        if ( socketFd == -1 )
        {
            return -1;
        }

        if ( setsockopt( socketFd,
                         SOL_SOCKET,
                         SO_RCVBUF,
                         &aReceiveBufferSize,
                         sizeof( aReceiveBufferSize ) ) == -1 )
        {
            close( socketFd );
            return -1;
        }

        if ( bind( socketFd,
                   ( const struct sockaddr * )aLocal,
                   sizeof( *aLocal ) ) == -1 )
        {
            *aBindFailed = true;
            close( socketFd );
            return -1;
        }

        return socketFd;
    }


    // Writes "<aFile>:<aLine>: Error <aVerb> to a.b.c.d:port" to stderr.
    // aAddress and aPort are in network byte order. The address is converted
    // to host order first, so the dotted quad prints correctly whatever the
    // host's endianness.
    void  reportBindError( const char *  aFile,
                           int           aLine,
                           const char *  aVerb,
                           in_addr_t     aAddress,
                           in_port_t     aPort )
    {
        unsigned long  hostAddress = ntohl( aAddress );

        fprintf( stderr,
                 "%s:%d: Error %s to %lu.%lu.%lu.%lu:%u\n",
                 aFile,
                 aLine,
                 aVerb,
                 hostAddress >> 24 & 0xff,
                 hostAddress >> 16 & 0xff,
                 hostAddress >>  8 & 0xff,
                 hostAddress       & 0xff,
                 ( unsigned int )ntohs( aPort ) );
    }


    // Sends the NUL-terminated aPacket as one datagram to aAddress:aPort
    // (network byte order) on aSocketFd. Stores strlen( aPacket ) in
    // *aLength and returns the sendto() result: bytes sent, or -1 on failure.
    ssize_t  sendUdpPacket( int           aSocketFd,
                            const char *  aPacket,
                            in_addr_t     aAddress,
                            in_port_t     aPort,
                            size_t *      aLength )
    {
        struct sockaddr_in  remoteAddress;

        memset( &remoteAddress, '\0', sizeof( remoteAddress ) );
        remoteAddress.sin_family      = AF_INET;
        remoteAddress.sin_addr.s_addr = aAddress;
        remoteAddress.sin_port        = aPort;

        *aLength = strlen( aPacket );
        // PRINT_SENT_PACKETS printf( ">>>BEGIN %s:%d\n%s\n>>>END\n", inet_ntoa( remoteAddress.sin_addr ), ntohs( remoteAddress.sin_port ), aPacket );
        return sendto( aSocketFd,
                       aPacket,
                       *aLength,
                       0,
                       ( const struct sockaddr * )&remoteAddress,
                       sizeof( remoteAddress ) );
    }
}


// Runs the port's I/O loop until Stop() is called or an error occurs.
//
// The socket is bound to mAddress:mPort (network byte order). Each pass of
// the loop:
//   1. receives every datagram already waiting, wrapping each in a new
//      SipMessage that is appended to the incoming queue;
//   2. unless the hold state is set, sends at most one packet, taken in
//      priority order from the outgoing message queue, then the responding
//      UDP sources, then (only while the incoming queue is empty) the
//      requesting and the retransmitting UDP sources.
// The first select() of a pass waits 0 us right after a successful send and
// 1 ms otherwise, so queued traffic drains without delay.
//
// Returns ERROR_NONE once Stop() has been observed; otherwise the first error
// hit (ERROR_MEMORY, ERROR_SOCKET, ERROR_BIND, ERROR_SELECT, ERROR_RECVFROM,
// ERROR_SENDTO or ERROR_SENDTO_ALL_BYTES_NOT_SENT). The socket and receive
// buffer are released on every path.
SipUdpPort::Error  SipUdpPort::Run( void )
{
    const int           socketReceiveBufferSize = 32768;
    udpRunResources     resources;
    bool                bindFailed;
    char *              packet;
    fd_set              socketFdSet;
    in_addr_t           destinationAddress;
    in_port_t           destinationPort;
    int                 selectStatus;
    long int            timeoutUsec;
    size_t              packetLength;
    ssize_t             packetSent;
    ssize_t             receivedLength;
    SipMessage *        message;
    socklen_t           socketAddressLength;
    struct sockaddr_in  localAddress;
    struct sockaddr_in  remoteAddress;
    struct timeval      timeout;
    UdpSource *         udpSource;

    resources.Payload = ( char * )malloc( MAX_PAYLOAD_LENGTH + 1 );
    if ( !resources.Payload )
    {
        return ERROR_MEMORY;
    }
    resources.Payload[MAX_PAYLOAD_LENGTH] = '\0';

    memset( &localAddress, '\0', sizeof( localAddress ) );
    localAddress.sin_family      = AF_INET;
    localAddress.sin_addr.s_addr = mAddress;
    localAddress.sin_port        = mPort;

    resources.SocketFd = openBoundUdpSocket( &localAddress,
                                             socketReceiveBufferSize,
                                             &bindFailed );
    if ( resources.SocketFd == -1 )
    {
        if ( bindFailed )
        {
            reportBindError( __FILE__, __LINE__, "binding", mAddress, mPort );
            return ERROR_BIND;
        }
        return ERROR_SOCKET;
    }

    mStopFlag   = 0;
    timeoutUsec = 1000; // one millisecond

    while ( !mStopFlag )
    {
        // ---- Receive: queue every datagram that is already waiting. ----
        for ( ;; )
        {
            // select() overwrites both the fd set and the timeout, and the
            // socket may have been recreated below, so rebuild them per call.
            FD_ZERO( &socketFdSet );
            FD_SET( resources.SocketFd, &socketFdSet );
            timeout.tv_sec  = 0;
            timeout.tv_usec = timeoutUsec;
            timeoutUsec     = 1000;

            selectStatus = select( resources.SocketFd + 1,
                                   &socketFdSet,
                                   NULL,
                                   NULL,
                                   &timeout );
            if ( selectStatus == -1 )
            {
                if ( errno == EINTR )
                {
                    break; // interrupted by a signal: treat like a timeout
                }
                return ERROR_SELECT;
            }
            if ( selectStatus == 0 )
            {
                break; // nothing (more) to read
            }

            socketAddressLength = sizeof( remoteAddress );
            receivedLength = recvfrom( resources.SocketFd,
                                       resources.Payload,
                                       MAX_PAYLOAD_LENGTH,
                                       0,
                                       ( struct sockaddr * )&remoteAddress,
                                       &socketAddressLength );
            if ( receivedLength >= 0 )
            {
                resources.Payload[receivedLength] = '\0';
                // PRINT_SENT_PACKETS printf( "<<<BEGIN %s:%d\n%s\n<<<END\n", inet_ntoa( remoteAddress.sin_addr ), ntohs( remoteAddress.sin_port ), resources.Payload );

                mMessagesReceived++;
                mBytesReceived += receivedLength;

                AddIncoming( new SipMessage( remoteAddress.sin_addr.s_addr,
                                             remoteAddress.sin_port,
                                             localAddress.sin_addr.s_addr,
                                             localAddress.sin_port,
                                             resources.Payload ) );
            }
            else if ( errno == ECONNRESET )
            {
                // KLUDGE: This only seems to happen under Cygwin. Why? I
                // have no idea. The error means "Connection forcibly
                // closed by peer", but this is a connectionless UDP
                // recvfrom (anyone!) being called. Recreate the socket with
                // the same options and binding as the original one.
                fprintf( stderr, "Cygwin kludge error correction.\n" );
                resources.CloseSocket();
                resources.SocketFd = openBoundUdpSocket( &localAddress,
                                                         socketReceiveBufferSize,
                                                         &bindFailed );
                if ( resources.SocketFd == -1 )
                {
                    if ( bindFailed )
                    {
                        reportBindError( __FILE__,
                                         __LINE__,
                                         "rebinding",
                                         mAddress,
                                         mPort );
                        return ERROR_BIND;
                    }
                    return ERROR_SOCKET;
                }
            }
            else if ( errno != EINTR )
            {
                return ERROR_RECVFROM;
            }

            if ( mStopFlag )
            {
                break;
            }
        }

        // ---- Send: at most one packet per pass, in priority order. ----
        message            = NULL;
        packet             = NULL;
        destinationAddress = 0;
        destinationPort    = 0;

        if ( !mHoldState )
        {
            if ( ( message = GetNextOutgoing() ) != NULL )
            {
                packet             = message->GetPacket();
                destinationAddress = message->GetDestinationAddress();
                destinationPort    = message->GetDestinationPort();
            }
            else
            {
                // A dequeued source whose GetUdpPacket() yields NULL is
                // skipped and the next queue is tried. Requests and
                // retransmissions wait while received messages are queued.
                if ( ( udpSource = GetNextRespondingUdpSource() ) != NULL )
                {
                    packet = udpSource->GetUdpPacket();
                }
                if (   packet == NULL
                    && !GetIncomingQueueQueued()
                    && ( udpSource = GetNextRequestingUdpSource() ) != NULL )
                {
                    packet = udpSource->GetUdpPacket();
                }
                if (   packet == NULL
                    && !GetIncomingQueueQueued()
                    && ( udpSource = GetNextRetransmittingUdpSource() ) != NULL )
                {
                    packet = udpSource->GetUdpPacket();
                }
                if ( packet != NULL )
                {
                    destinationAddress = udpSource->GetUdpAddress();
                    destinationPort    = udpSource->GetUdpPort();
                }
            }
        }

        if ( packet == NULL )
        {
            delete message; // NULL unless GetPacket() produced nothing
            timeoutUsec = 1000;
        }
        else
        {
            packetSent = sendUdpPacket( resources.SocketFd,
                                        packet,
                                        destinationAddress,
                                        destinationPort,
                                        &packetLength );
            free( packet );
            delete message; // NULL when the packet came from a UdpSource
            if ( packetSent < 0 )
            {
                return ERROR_SENDTO;
            }
            if ( ( size_t )packetSent < packetLength )
            {
                return ERROR_SENDTO_ALL_BYTES_NOT_SENT;
            }
            mMessagesSent++;
            mBytesSent += packetSent;

            timeoutUsec = 0; // more may be queued: poll without waiting
        }

        sched_yield();
    }

    return ERROR_NONE;
}


void  SipUdpPort::Stop( void )
{
    mStopFlag = -1;
}


// Dequeues the oldest received message, or returns NULL when none is
// waiting. The caller becomes responsible for deleting the message.
SipMessage *  SipUdpPort::GetNextIncoming( void )
{
    SipMessage *  oldest = this->mIncomingQueue.GetNext();
    return oldest;
}


// Appends aMessage to the incoming queue, which takes ownership of it (the
// queue's destructor deletes anything still queued).
void  SipUdpPort::AddIncoming( SipMessage *  aMessage )
{
    this->mIncomingQueue.Add( aMessage );
}


// Dequeues the oldest message waiting to be sent, or returns NULL when the
// outgoing queue is empty. The caller becomes responsible for deleting it.
SipMessage *  SipUdpPort::GetNextOutgoing( void )
{
    SipMessage *  oldest = this->mOutgoingQueue.GetNext();
    return oldest;
}

// Queues aMessage for sending; the outgoing queue takes ownership of it.
// Run() serves this queue ahead of every UDP source queue.
void  SipUdpPort::AddOutgoing( SipMessage *  aMessage )
{
    this->mOutgoingQueue.Add( aMessage );
}


// Dequeues the next UdpSource waiting to send a response, or returns NULL.
// The source object itself is not deleted by the queue.
SipUdpPort::UdpSource *  SipUdpPort::GetNextRespondingUdpSource( void )
{
    UdpSource *  next = this->mRespondingUdpSourceQueue.GetNext();
    return next;
}


// Queues aSource to send a response. A source already in this queue is
// moved to the back rather than queued twice.
void  SipUdpPort::AddRespondingUdpSource( UdpSource *  aSource )
{
    this->mRespondingUdpSourceQueue.Add( aSource );
}


// Withdraws aSource from the responding queue; a no-op if it is not queued.
void  SipUdpPort::RemoveRespondingUdpSource( UdpSource *  aSource )
{
    this->mRespondingUdpSourceQueue.Remove( aSource );
}


// Dequeues the next UdpSource waiting to send a request, or returns NULL.
// The source object itself is not deleted by the queue.
SipUdpPort::UdpSource *  SipUdpPort::GetNextRequestingUdpSource( void )
{
    UdpSource *  next = this->mRequestingUdpSourceQueue.GetNext();
    return next;
}


// Queues aSource to send a request. A source already in this queue is moved
// to the back rather than queued twice.
void  SipUdpPort::AddRequestingUdpSource( UdpSource *  aSource )
{
    this->mRequestingUdpSourceQueue.Add( aSource );
}


// Withdraws aSource from the requesting queue; a no-op if it is not queued.
void  SipUdpPort::RemoveRequestingUdpSource( UdpSource *  aSource )
{
    this->mRequestingUdpSourceQueue.Remove( aSource );
}


// Dequeues the next UdpSource waiting to retransmit, or returns NULL.
// The source object itself is not deleted by the queue.
SipUdpPort::UdpSource *  SipUdpPort::GetNextRetransmittingUdpSource( void )
{
    UdpSource *  next = this->mRetransmittingUdpSourceQueue.GetNext();
    return next;
}


// Queues aSource to retransmit. A source already in this queue is moved to
// the back rather than queued twice.
void  SipUdpPort::AddRetransmittingUdpSource( UdpSource *  aSource )
{
    this->mRetransmittingUdpSourceQueue.Add( aSource );
}


// Withdraws aSource from the retransmitting queue; a no-op if not queued.
void  SipUdpPort::RemoveRetransmittingUdpSource( UdpSource *  aSource )
{
    this->mRetransmittingUdpSourceQueue.Remove( aSource );
}


time_t  SipUdpPort::GetCreationTime( void )
{
    return mCreationTime;
}


in_addr_t  SipUdpPort::GetBoundAddress( void )
{
    return mAddress;
}


void  SipUdpPort::SetBoundAddress( in_addr_t  aValue )
{
    mAddress = aValue;
}


in_port_t  SipUdpPort::GetBoundPort( void )
{
    return mPort;
}


void  SipUdpPort::SetBoundPort( in_port_t  aValue )
{
    mPort = aValue;
}


unsigned long  SipUdpPort::GetMessagesSent( void )
{
    return mMessagesSent;
}


unsigned long  SipUdpPort::GetMessagesReceived( void )
{
    return mMessagesReceived;
}


unsigned int  SipUdpPort::GetOutgoingQueueQueued( void )
{
    return mOutgoingQueue.GetQueuedCount();
}


unsigned int  SipUdpPort::GetIncomingQueueQueued( void )
{
    return mIncomingQueue.GetQueuedCount();
}


unsigned int  SipUdpPort::GetOutgoingQueueMostQueued( void )
{
    return mOutgoingQueue.GetMostQueuedCount();
}


unsigned int  SipUdpPort::GetIncomingQueueMostQueued( void )
{
    return mIncomingQueue.GetMostQueuedCount();
}


unsigned int  SipUdpPort::GetRespondingUdpSourceQueueQueued( void )
{
    return mRespondingUdpSourceQueue.GetQueuedCount();
}


unsigned int  SipUdpPort::GetRespondingUdpSourceQueueMostQueued( void )
{
    return mRespondingUdpSourceQueue.GetMostQueuedCount();
}


unsigned int  SipUdpPort::GetRequestingUdpSourceQueueQueued( void )
{
    return mRequestingUdpSourceQueue.GetQueuedCount();
}


unsigned int  SipUdpPort::GetRequestingUdpSourceQueueMostQueued( void )
{
    return mRequestingUdpSourceQueue.GetMostQueuedCount();
}


unsigned int  SipUdpPort::GetRetransmittingUdpSourceQueueQueued( void )
{
    return mRetransmittingUdpSourceQueue.GetQueuedCount();
}


unsigned int  SipUdpPort::GetRetransmittingUdpSourceQueueMostQueued( void )
{
    return mRetransmittingUdpSourceQueue.GetMostQueuedCount();
}


unsigned long long  SipUdpPort::GetBytesSent( void )
{
    return mBytesSent;
}


unsigned long long  SipUdpPort::GetBytesReceived( void )
{
    return mBytesReceived;
}


void  SipUdpPort::SetHoldState( bool  aValue )
{
    mHoldState = aValue;
}


// Creates an empty FIFO of SipMessage pointers, guarded by mMutex.
SipUdpPort::messageQueue::messageQueue( void )
{
    pthread_mutex_init( &mMutex, NULL );

    // Empty singly-linked list.
    mHead = NULL;
    mTail = NULL;

    // Current length and high-water mark.
    mMostQueued = 0;
    mQueued     = 0;
}


// Deletes every message still queued (the queue owns them), then destroys
// the mutex.
SipUdpPort::messageQueue::~messageQueue( void )
{
    for ( SipMessage *  leftover = GetNext();
          leftover != NULL;
          leftover = GetNext() )
    {
        delete leftover;
    }
    pthread_mutex_destroy( &mMutex );
}


// Removes and returns the message at the head of the queue, or NULL when the
// queue is empty. Ownership of the returned message passes to the caller.
SipMessage *  SipUdpPort::messageQueue::GetNext( void )
{
    SipMessage *   message;
    messageNode *  node;

    message = NULL;

    pthread_mutex_lock( &mMutex );
    node = mHead;
    if ( node )
    {
        mHead      = node->Next;
        node->Next = NULL;
        if ( !mHead )
        {
            mTail = NULL;
        }
        // Decrement while still holding the lock: Add() increments under the
        // same mutex, and an unlocked read-modify-write could lose updates.
        mQueued--;
    }
    pthread_mutex_unlock( &mMutex );

    if ( node )
    {
        message = node->Message;
        delete node;
    }

    return message;
}


// Appends aMessage to the tail of the queue. The queue owns the message
// until GetNext() hands it out; the destructor deletes anything left behind.
// When MAX_MESSAGE_QUEUE_SIZE is defined, the oldest messages beyond that
// limit are unlinked while the lock is held and deleted after it is
// released, to keep the critical section short.
void  SipUdpPort::messageQueue::Add( SipMessage *  aMessage )
{
    messageNode *  node;
#ifdef MAX_MESSAGE_QUEUE_SIZE
    messageNode *  dropped;      // first node of the chain of trimmed nodes
    messageNode *  droppedLast;  // last node appended to that chain
    messageNode *  oldest;
#endif

    node = new messageNode();
    node->Next    = NULL;
    node->Message = aMessage;

    pthread_mutex_lock( &mMutex );
    if ( mTail )
    {
        mTail->Next = node;
        mTail       = node;
    }
    else
    {
        mHead = mTail = node;
    }
    mQueued++;
#ifdef MAX_MESSAGE_QUEUE_SIZE
    dropped     = NULL;
    droppedLast = NULL;
    while ( mQueued > MAX_MESSAGE_QUEUE_SIZE )
    {
        oldest       = mHead;
        mHead        = oldest->Next;
        oldest->Next = NULL;
        if ( droppedLast )
        {
            droppedLast->Next = oldest;
        }
        else
        {
            dropped = oldest;
        }
        droppedLast = oldest;
        mQueued--;
    }
    if ( !mHead )
    {
        // Trimming removed every node (possible with a limit of 0), including
        // the one just appended: don't leave mTail pointing at it.
        mTail = NULL;
    }
#endif
    if ( mQueued > mMostQueued )
    {
        mMostQueued = mQueued;
    }
    pthread_mutex_unlock( &mMutex );

#ifdef MAX_MESSAGE_QUEUE_SIZE
    while ( dropped )
    {
        node          = dropped->Next;
        dropped->Next = NULL;
        delete dropped->Message;
        delete dropped;
        dropped = node;
    }
#endif
}


unsigned int  SipUdpPort::messageQueue::GetQueuedCount( void )
{
    return mQueued;
}


unsigned int  SipUdpPort::messageQueue::GetMostQueuedCount( void )
{
    return mMostQueued;
}


// Creates an empty FIFO of UdpSource pointers, guarded by mMutex.
SipUdpPort::sourceQueue::sourceQueue( void )
{
    pthread_mutex_init( &mMutex, NULL );

    // Empty singly-linked list.
    mHead = NULL;
    mTail = NULL;

    // Current length and high-water mark.
    mMostQueued = 0;
    mQueued     = 0;
}


// Frees the list nodes and destroys the mutex. The queued UdpSource objects
// themselves are not deleted.
SipUdpPort::sourceQueue::~sourceQueue( void )
{
    while ( GetNext() != NULL )
    {
        // discard: only the list node is freed by GetNext()
    }
    pthread_mutex_destroy( &mMutex );
}


// Removes and returns the UdpSource at the head of the queue, or NULL when
// the queue is empty. Only the list node is freed; the UdpSource is not.
SipUdpPort::UdpSource *  SipUdpPort::sourceQueue::GetNext( void )
{
    UdpSource *   source;
    sourceNode *  node;

    source = NULL;

    pthread_mutex_lock( &mMutex );
    node = mHead;
    if ( node )
    {
        mHead      = node->Next;
        node->Next = NULL;
        if ( !mHead )
        {
            mTail = NULL;
        }
        // Decrement while still holding the lock: Add() and Remove() update
        // the count under the same mutex, and an unlocked update could lose
        // counts.
        mQueued--;
    }
    pthread_mutex_unlock( &mMutex );

    if ( node )
    {
        source = node->Source;
        delete node;
    }

    return source;
}


// Appends aSource to the tail of the queue. A source is queued at most once:
// if aSource is already present, its old entry is unlinked first, so the
// source effectively moves to the back. The UdpSource object is not owned.
void  SipUdpPort::sourceQueue::Add( UdpSource *  aSource )
{
    sourceNode *  node;
    sourceNode *  scan;
    sourceNode *  scanPrevious;

    node = new sourceNode();
    node->Next   = NULL;
    node->Source = aSource;

    pthread_mutex_lock( &mMutex );

    scanPrevious = NULL;
    scan         = mHead;
    while ( scan && scan->Source != aSource )
    {
        scanPrevious = scan;
        scan         = scan->Next;
    }
    if ( scan )
    {
        if ( scanPrevious )
        {
            scanPrevious->Next = scan->Next;
        }
        else
        {
            mHead = scan->Next;
        }
        if ( scan == mTail )
        {
            // The unlinked node was the tail. Without this, the append below
            // would write through a freed node and, for a one-element queue,
            // leave mHead NULL while mTail is non-NULL.
            mTail = scanPrevious;
        }
        scan->Next = NULL;
        delete scan;
        mQueued--;
    }

    if ( mTail )
    {
        mTail->Next = node;
        mTail       = node;
    }
    else
    {
        mHead = mTail = node;
    }
    mQueued++;
    if ( mQueued > mMostQueued )
    {
        mMostQueued = mQueued;
    }

    pthread_mutex_unlock( &mMutex );
}


// Unlinks the entry for aSource if it is queued; does nothing otherwise.
// Only the list node is freed; the UdpSource object is not deleted.
void  SipUdpPort::sourceQueue::Remove( UdpSource *  aSource )
{
    sourceNode *  scan;
    sourceNode *  scanPrevious;

    pthread_mutex_lock( &mMutex );

    scanPrevious = NULL;
    scan         = mHead;
    while ( scan && scan->Source != aSource )
    {
        scanPrevious = scan;
        scan         = scan->Next;
    }
    if ( scan )
    {
        if ( scanPrevious )
        {
            scanPrevious->Next = scan->Next;
        }
        else
        {
            mHead = scan->Next;
        }
        if ( scan == mTail )
        {
            // Removing the tail must move mTail back; otherwise the next
            // Add() appends through a freed node.
            mTail = scanPrevious;
        }
        scan->Next = NULL;
        delete scan;
        mQueued--;
    }

    pthread_mutex_unlock( &mMutex );
}


unsigned int  SipUdpPort::sourceQueue::GetQueuedCount( void )
{
    return mQueued;
}


unsigned int  SipUdpPort::sourceQueue::GetMostQueuedCount( void )
{
    return mMostQueued;
}

