//-----------------------------------------------------------------------------
//
// SipMessage.cpp - Encapsulates SIP requests and responses.
//
//    Copyright (C) 2004  Mark D. Collier
//
//    This program is free software; you can redistribute it and/or modify
//    it under the terms of the GNU General Public License as published by
//    the Free Software Foundation; either version 2 of the License, or
//    (at your option) any later version.
//
//    This program is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU General Public License for more details.
//
//    You should have received a copy of the GNU General Public License
//    along with this program; if not, write to the Free Software
//    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
//   Author: Mark D. Collier   - 12/01/2006   v1.1
//                   Mark D. Collier   -  04/26/2004  v1.0
//         www.securelogix.com - mark.collier@securelogix.com
//         www.hackingexposedvoip.com

//-----------------------------------------------------------------------------

#include <arpa/inet.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "util.h"
#include "SipMessage.h"
#include "SipHeader.h"
#include "SipUri.h"

#define malloc  mymalloc
#define free    myfree
#define strdup  mystrdup
#define strndup mystrndup

// Puts every member into its default state: the given endpoints and
// message kind, protocol version 2.0, an empty header list, no content,
// no start-line fields (method, URI, status code, reason text) and no
// recorded error.
void  SipMessage::constructorHelper( in_addr_t  aSourceAddress,
                                     in_port_t  aSourcePort,
                                     in_addr_t  aDestinationAddress,
                                     in_port_t  aDestinationPort,
                                     bool       aRequestFlag )
{
    // Addressing and message kind.
    mRequestFlag        = aRequestFlag;
    mSourceAddress      = aSourceAddress;
    mSourcePort         = aSourcePort;
    mDestinationAddress = aDestinationAddress;
    mDestinationPort    = aDestinationPort;

    // Protocol version defaults to SIP/2.0.
    mMajorVersion = 2;
    mMinorVersion = 0;

    // Start-line fields.
    mRequestMethod = NULL;
    mRequestUri    = NULL;
    mResponseCode  = 0;
    mResponseText  = NULL;

    // Header list and body.
    mHeaderListHead = NULL;
    mHeaderListTail = NULL;
    mContent        = NULL;
    mContentSize    = 0;

    // Error state.
    mError     = NULL;
    mErrorInfo = NULL;
}


// aPayloadSize may be -1 if aPayload is NULL terminated.
void  SipMessage::constructorHelper( in_addr_t  aSourceAddress,
                                     in_port_t  aSourcePort,
                                     in_addr_t  aDestinationAddress,
                                     in_port_t  aDestinationPort,
                                     char *     aPayload,
                                     int        aPayloadSize )
{
    char *  cp;
    char *  cp2;
    char *  cp3;

    // FIXME: This is the easy way out.
    if ( aPayloadSize >= 0 )
    {
        cp = ( char * )malloc( aPayloadSize + 1 );
        if ( !cp )
        {
            return;
        }
        memcpy( cp, aPayload, aPayloadSize );
        cp[aPayloadSize] = '\0';
        aPayload = cp;
        aPayloadSize = -1;
    }

    if ( strncmp( "SIP/", aPayload, 4 ) == 0 )
    {
        constructorHelper( aSourceAddress,
                           aSourcePort,
                           aDestinationAddress,
                           aDestinationPort,
                           0 );

        cp  = aPayload + 4;
        cp2 = strchr( cp, '.' );
        if ( !cp2 || cp2 == cp )
        {
            mError     = "Bad response version";
            mErrorInfo = cp;
            return;
        }
        mMajorVersion = strtol( cp, NULL, 10 );
        cp2++;

        cp = strchr( cp2, ' ' );
        if ( !cp || cp == cp2 )
        {
            mError     = "Bad response version";
            mErrorInfo = cp2;
            return;
        }
        mMinorVersion = strtol( cp2, NULL, 10 );
        cp++;

        cp2 = strchr( cp, ' ' );
        if ( !cp2 || cp2 == cp )
        {
            mError     = "Bad response status code";
            mErrorInfo = cp;
            return;
        }
        SetResponseCode( strtol( cp, NULL, 10 ) );
        cp2++;

        cp = strchr( cp2, '\n' );
        if ( !cp || cp == cp2 )
        {
            mError     = "Bad response status message";
            mErrorInfo = cp2;
            return;
        }
        cp3 = cp + 1;
        if ( cp[-1] == '\r' )
        {
            cp--;
        }
        SetResponseText( strndup( cp2, cp - cp2 ) );

        cp = addHeaders( cp3 );
        if ( !cp )
        {
            return;
        }

        SetContent( strdup( cp ) );
    }
    else
    {
        constructorHelper( aSourceAddress,
                           aSourcePort,
                           aDestinationAddress,
                           aDestinationPort,
                           1 );
        if ( mError )
        {
            return;
        }

        cp = strchr( aPayload, ' ' );
        if ( !cp || cp == aPayload )
        {
            mError     = "Bad message";
            mErrorInfo = aPayload;
            return;
        }
        SetRequestMethod( strndup( aPayload, cp - aPayload ) );
        cp++;

        cp2 = strchr( cp, '\n' );
        if ( !cp2 || cp2 == cp )
        {
            mError     = "Bad message";
            mErrorInfo = cp;
            return;
        }

        cp3 = strchr( cp, ' ' );
        if ( !cp3 )
        {
            mError     = "Bad request uri";
            mErrorInfo = cp;
            return;
        }
        SetRequestUri( new SipUri( strndup( cp, cp3 - cp ) ) );
        cp = cp3 + 1;

        if ( strncmp( "SIP/", cp, 4 ) != 0 )
        {
            mError     = "Bad request protocol";
            mErrorInfo = cp;
            return;
        }
        cp += 4;

        cp3 = strchr( cp, '.' );
        if ( !cp3 || cp3 == cp )
        {
            mError     = "Bad request version";
            mErrorInfo = cp;
            return;
        }
        mMajorVersion = strtol( cp, NULL, 10 );
        mMinorVersion = strtol( cp3 + 1, NULL, 10 );

        cp = addHeaders( cp2 + 1 );
        if ( !cp )
        {
            return;
        }

        SetContent( strdup( cp ) );
    }
}


// Builds an empty request (aIsRequest true) or response whose source and
// destination are both 127.0.0.1, port 5060.
SipMessage::SipMessage( bool aIsRequest )
{
    const in_addr_t  loopback = inet_addr( "127.0.0.1" );
    const in_port_t  sipPort  = htons( 5060 );

    constructorHelper( loopback, sipPort, loopback, sipPort, aIsRequest );
}


// Parses a NUL-terminated raw SIP message received from the given source
// and sent to the given destination.  Check GetError() afterwards.
SipMessage::SipMessage( in_addr_t  aSourceAddress,
                        in_port_t  aSourcePort,
                        in_addr_t  aDestinationAddress,
                        in_port_t  aDestinationPort,
                        char *     aPayload )
{
    // -1 tells the parser there is no explicit length: read up to the NUL.
    const int  nulTerminated = -1;

    constructorHelper( aSourceAddress, aSourcePort,
                       aDestinationAddress, aDestinationPort,
                       aPayload, nulTerminated );
}


// Parses aPayloadSize bytes of raw SIP message text (or a NUL-terminated
// string when aPayloadSize is -1).  Check GetError() afterwards.
SipMessage::SipMessage( in_addr_t  aSourceAddress,
                        in_port_t  aSourcePort,
                        in_addr_t  aDestinationAddress,
                        in_port_t  aDestinationPort,
                        char *     aPayload,
                        int        aPayloadSize )
{
    constructorHelper( aSourceAddress, aSourcePort,
                       aDestinationAddress, aDestinationPort,
                       aPayload, aPayloadSize );
}


// Parses "Name: value" header lines starting at aCp and appends each one to
// the header list with AddHeader().  Parsing stops at the first empty line
// or at the end of the string.  A line without a colon becomes a header
// whose name is the whole line and whose value is empty.
//
// Returns a pointer just past the CR/LF characters that follow the header
// section, i.e. the start of the body ("" when there is none).  Never
// returns NULL.
char *  SipMessage::addHeaders( char *  aCp )
{
    char *  lineStart;
    char *  lineEnd;     // one past the last text character of the line
    char *  nextLine;    // first character of the following line
    char *  colon;
    char *  nameEnd;
    char *  value;
    char *  name;
    char *  text;

    while ( *aCp && *aCp != '\r' && *aCp != '\n' )
    {
        lineStart = aCp;
        nextLine  = strchr( lineStart, '\n' );
        if ( nextLine )
        {
            lineEnd = nextLine;
            nextLine++;
            if ( lineEnd > lineStart && lineEnd[-1] == '\r' )
            {
                lineEnd--;
            }
        }
        else
        {
            // Last line without a terminator: keep every character of it.
            lineEnd  = lineStart + strlen( lineStart );
            nextLine = lineEnd;
        }

        // Only a colon on *this* line separates name from value.
        colon = ( char * )memchr( lineStart, ':', lineEnd - lineStart );
        if ( colon )
        {
            // RFC 3261 HCOLON allows SP/HTAB before and after the colon.
            nameEnd = colon;
            while ( nameEnd > lineStart &&
                    ( nameEnd[-1] == ' ' || nameEnd[-1] == '\t' ) )
            {
                nameEnd--;
            }
            value = colon + 1;
            while ( value < lineEnd && ( *value == ' ' || *value == '\t' ) )
            {
                value++;
            }
            name = strndup( lineStart, nameEnd - lineStart );
            text = strndup( value, lineEnd - value );
            AddHeader( new SipHeader( name, text ) );
            free( name );
            free( text );
        }
        else
        {
            name = strndup( lineStart, lineEnd - lineStart );
            AddHeader( new SipHeader( name, "" ) );
            free( name );
        }
        aCp = nextLine;
    }

    // Skip the blank line(s) separating the headers from the body.
    while ( *aCp && ( *aCp == '\r' || *aCp == '\n' ) )
    {
        aCp++;
    }
    return aCp;
}


// Releases everything the message owns: each header-list node together
// with the SipHeader it holds, the content buffer, the request method,
// the request URI and the response text.
SipMessage::~SipMessage( void )
{
    while ( mHeaderListHead != NULL )
    {
        headerNode *  doomed = mHeaderListHead;

        mHeaderListHead = doomed->Next;
        doomed->Next    = NULL;
        delete doomed->Value;
        delete doomed;
    }
    mHeaderListTail = NULL;

    if ( mContent != NULL )
    {
        free( mContent );
        mContent     = NULL;
        mContentSize = 0;
    }

    if ( mRequestMethod != NULL )
    {
        free( mRequestMethod );
        mRequestMethod = NULL;
    }

    if ( mRequestUri != NULL )
    {
        delete mRequestUri;
        mRequestUri = NULL;
    }

    if ( mResponseText != NULL )
    {
        free( mResponseText );
        mResponseText = NULL;
    }
}


// Returns the recorded error description, or NULL if no error was recorded.
char *  SipMessage::GetError() { return mError; }


// Returns a pointer to the text where parsing failed, or NULL if none was recorded.
char *  SipMessage::GetErrorInfo() { return mErrorInfo; }


// Returns the source address exactly as stored by the constructor or setter.
in_addr_t  SipMessage::GetSourceAddress() { return mSourceAddress; }


// Replaces the stored source address.
void  SipMessage::SetSourceAddress( in_addr_t  aAddress ) { mSourceAddress = aAddress; }


// Returns the destination address exactly as stored by the constructor or setter.
in_addr_t  SipMessage::GetDestinationAddress() { return mDestinationAddress; }


// Replaces the stored destination address.
void  SipMessage::SetDestinationAddress( in_addr_t  aAddress ) { mDestinationAddress = aAddress; }


// Returns the stored source port (presumably network byte order; the default
// constructor stores htons(5060)).
in_port_t  SipMessage::GetSourcePort() { return mSourcePort; }


// Replaces the stored source port; the value is stored unchanged.
void  SipMessage::SetSourcePort( in_port_t  aPort ) { mSourcePort = aPort; }


// Returns the stored destination port (presumably network byte order; the
// default constructor stores htons(5060)).
in_port_t  SipMessage::GetDestinationPort() { return mDestinationPort; }


// Replaces the stored destination port; the value is stored unchanged.
void  SipMessage::SetDestinationPort( in_port_t  aPort ) { mDestinationPort = aPort; }


// True when this message is a request (as opposed to a response).
bool  SipMessage::IsRequest() { return mRequestFlag; }


// True when this message is a response; always the negation of the request flag.
bool  SipMessage::IsResponse() { return mRequestFlag ? false : true; }


// Major SIP version: 2 by default, otherwise as parsed from the start line.
int  SipMessage::GetMajorVersion() { return mMajorVersion; }


// Minor SIP version: 0 by default, otherwise as parsed from the start line.
int  SipMessage::GetMinorVersion() { return mMinorVersion; }


// Returns the request method, or an empty string when none has been set.
// Never returns NULL.
char *  SipMessage::GetRequestMethod( void )
{
    if ( !mRequestMethod )
    {
        return "";
    }
    return mRequestMethod;
}


// Replaces the request method.  The message takes ownership of aMethod
// (a malloc'd string that is freed later).  Passing the string the message
// already holds is a no-op; previously it was freed and then kept, which
// left a dangling pointer.
void  SipMessage::SetRequestMethod( char *  aMethod )
{
    if ( mRequestMethod && mRequestMethod != aMethod )
    {
        free( mRequestMethod );
    }
    mRequestMethod = aMethod;
}


// Returns the request URI, or NULL when none has been set.  The message keeps
// ownership of the returned object.
SipUri *  SipMessage::GetRequestUri( void )
{
    return mRequestUri;
}


// Replaces the request URI.  The message takes ownership of aUri and
// deletes it later, either in the destructor or when a different URI is
// set.  Passing the URI already held is a no-op; previously it was deleted
// and then kept, which left a dangling pointer.
void  SipMessage::SetRequestUri( SipUri *  aUri )
{
    if ( mRequestUri && mRequestUri != aUri )
    {
        delete mRequestUri;
    }
    mRequestUri = aUri;
}


// Response status code; 0 until one is set or parsed.
int  SipMessage::GetResponseCode() { return mResponseCode; }


// Stores the response status code without validation.
void  SipMessage::SetResponseCode( int  aCode ) { mResponseCode = aCode; }


// Returns the response reason phrase, or an empty string when none has been
// set.  Never returns NULL.
char *  SipMessage::GetResponseText( void )
{
    if ( !mResponseText )
    {
        return "";
    }
    return mResponseText;
}


// Replaces the response reason phrase.  The message takes ownership of aText
// (a malloc'd string that is freed later).  Passing the string already held
// is a no-op; previously it was freed and then kept, which left a dangling
// pointer.
void  SipMessage::SetResponseText( char *  aText )
{
    if ( mResponseText && mResponseText != aText )
    {
        free( mResponseText );
    }
    mResponseText = aText;
}


// Appends aHeader to the end of the header list.  The message takes
// ownership; the destructor deletes it.  If the list node cannot be
// allocated, mError is set to "Out of memory".  That check is only
// reachable when operator new returns NULL rather than throwing.
void  SipMessage::AddHeader( SipHeader *  aHeader )
{
    headerNode *  entry = new headerNode();

    if ( entry == NULL )
    {
        mError = "Out of memory";
        return;
    }

    entry->Value = aHeader;
    entry->Next  = NULL;

    if ( mHeaderListTail == NULL )
    {
        mHeaderListHead = entry;
    }
    else
    {
        mHeaderListTail->Next = entry;
    }
    mHeaderListTail = entry;
}


// Prepends aHeader to the front of the header list.  The message takes
// ownership; the destructor deletes it.  When the list was empty, the new
// node is also the tail.  Previously the tail stayed NULL, so a later
// AddHeader() overwrote the head and lost (leaked) this node.
void  SipMessage::InsertHeader( SipHeader *  aHeader )
{
    headerNode *  node;

    node = new headerNode();
    if ( !node )
    {
        mError = "Out of memory";
        return;
    }
    node->Value     = aHeader;
    node->Next      = mHeaderListHead;
    mHeaderListHead = node;
    if ( !mHeaderListTail )
    {
        mHeaderListTail = node;
    }
}


// Removes and deletes the first header whose name matches aName
// (case-insensitive).  Returns true if a header was removed.
// The tail pointer is kept valid when the removed node was the last one.
// Previously it dangled, and the next AddHeader() wrote into freed memory.
bool  SipMessage::RemoveHeader( char *  aName )
{
    headerNode *  previousNode;
    headerNode *  node;

    previousNode = NULL;
    node         = mHeaderListHead;
    while ( node )
    {
        if ( strcasecmp( node->Value->GetName(), aName ) == 0 )
        {
            if ( previousNode )
            {
                previousNode->Next = node->Next;
            }
            else
            {
                mHeaderListHead = node->Next;
            }
            if ( node == mHeaderListTail )
            {
                // NULL when the list is now empty.
                mHeaderListTail = previousNode;
            }
            delete node->Value;
            node->Value = NULL;
            node->Next  = NULL;
            delete node;
            return true;
        }
        previousNode = node;
        node         = node->Next;
    }

    return false;
}


// Returns the first header whose name matches aName (case-insensitive), or
// NULL if there is none.  The message keeps ownership.
SipHeader *  SipMessage::GetHeader( char *  aName )
{
    for ( headerNode *  entry = mHeaderListHead; entry != NULL; entry = entry->Next )
    {
        if ( !strcasecmp( entry->Value->GetName(), aName ) )
        {
            return entry->Value;
        }
    }
    return NULL;
}


// Returns the header at zero-based position aIndex, or NULL when aIndex is
// negative or past the end of the list.  The message keeps ownership.
SipHeader *  SipMessage::GetHeader( int  aIndex )
{
    if ( aIndex < 0 )
    {
        return NULL;
    }

    headerNode *  cursor = mHeaderListHead;
    while ( cursor != NULL && aIndex > 0 )
    {
        cursor = cursor->Next;
        aIndex--;
    }

    return cursor ? cursor->Value : NULL;
}


int  SipMessage::GetHeaderCount( void )
{
    headerNode *  node;
    int           count;

    count = 0;
    for ( node = mHeaderListHead; node; node = node->Next )
    {
        count++;
    }

    return count;
}


// Returns the message body, or an empty string when there is none.
// Never returns NULL.
char *  SipMessage::GetContent( void )
{
    if ( mContent )
    {
        return mContent;
    }
    return "";
}


// Replaces the message body with aValue, a malloc'd NUL-terminated string
// that the message takes ownership of, and records its length.
// Passing NULL clears the content; previously strlen(NULL) crashed.
// Passing the buffer already held just recomputes the size; previously it
// was freed and then read.
void  SipMessage::SetContent( char * aValue )
{
    if ( mContent && mContent != aValue )
    {
        free( mContent );
    }
    mContent     = aValue;
    mContentSize = aValue ? ( int )strlen( aValue ) : 0;
}


// Length in bytes of the body, as recorded by SetContent(); 0 when empty.
int  SipMessage::GetContentSize() { return mContentSize; }


// Appends printf-style text at *aPosition in a buffer being assembled by
// SipMessage::GetPacket().
//
// When aBuffer is NULL (the measuring pass), or *aPosition is already at or
// past aCapacity, nothing is written.  In every case *aPosition advances by
// the full length the text needs.  Returns false only if vsnprintf reports
// an error.
static bool  appendPacketText( char *        aBuffer,
                               size_t        aCapacity,
                               size_t *      aPosition,
                               const char *  aFormat,
                               ... )
{
    va_list  args;
    char *   destination = NULL;
    size_t   room        = 0;
    int      written;

    if ( aBuffer && *aPosition < aCapacity )
    {
        destination = aBuffer + *aPosition;
        room        = aCapacity - *aPosition;
    }

    va_start( args, aFormat );
    written = vsnprintf( destination, room, aFormat, args );
    va_end( args );

    if ( written < 0 )
    {
        return false;
    }
    *aPosition += ( size_t )written;
    return true;
}


// Serializes the message into a newly allocated, NUL-terminated buffer:
// the start line, each header followed by CRLF, an empty CRLF line, then
// the content.  The caller owns the result and releases it with this file's
// free().  Returns NULL on allocation failure or a formatting error.
//
// The size is measured exactly in a first pass and written in a second.
// The previous guess-and-double loop had two failures:
//  - it never terminated for a message with no headers and no content,
//    because the initial size 0 doubled to 0;
//  - it returned a truncated, unterminated packet when the content did
//    not fit in the remaining space.
char *  SipMessage::GetPacket( void )
{
    char *  packet      = NULL;
    size_t  capacity    = 0;
    size_t  position    = 0;
    int     headerCount = GetHeaderCount();
    int     contentSize = GetContentSize();
    int     headerIndex;
    int     pass;
    bool    ok;

    for ( pass = 0; pass < 2; pass++ )
    {
        if ( pass == 1 )
        {
            capacity = position + 1;   // measured length + terminating NUL
            packet   = ( char * )malloc( capacity );
            if ( !packet )
            {
                return NULL;
            }
        }
        position = 0;
        ok       = true;

        if ( IsRequest() )
        {
            // A request without a URI is written with an empty URI rather
            // than dereferencing NULL.
            ok = appendPacketText( packet, capacity, &position,
                                   "%s %s SIP/%d.%d\r\n",
                                   GetRequestMethod(),
                                   GetRequestUri() ? GetRequestUri()->GetFullText() : "",
                                   GetMajorVersion(),
                                   GetMinorVersion() );
        }
        else if ( IsResponse() )
        {
            ok = appendPacketText( packet, capacity, &position,
                                   "SIP/%d.%d %d %s\r\n",
                                   GetMajorVersion(),
                                   GetMinorVersion(),
                                   GetResponseCode(),
                                   GetResponseText() );
        }

        for ( headerIndex = 0; ok && headerIndex < headerCount; headerIndex++ )
        {
            ok = appendPacketText( packet, capacity, &position,
                                   "%s\r\n",
                                   GetHeader( headerIndex )->GetFullText() );
        }

        if ( ok )
        {
            ok = appendPacketText( packet, capacity, &position, "\r\n" );
        }

        if ( !ok )
        {
            if ( packet )
            {
                free( packet );
            }
            return NULL;
        }

        if ( contentSize > 0 )
        {
            if ( packet && position + ( size_t )contentSize < capacity )
            {
                memcpy( packet + position, GetContent(), contentSize );
            }
            position += ( size_t )contentSize;
        }
    }

    // Defensive: both passes must agree.  They can differ only if some
    // text changed between the measuring and writing passes.
    if ( position >= capacity )
    {
        free( packet );
        return NULL;
    }
    packet[position] = '\0';
    return packet;
}

