#ifndef _H_SIMPLESOCK_
#define _H_SIMPLESOCK_

#if (defined(_WIN32) && !defined(WIN32))
#define WIN32
#endif

#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <streambuf>
#include <string>

#ifdef WIN32
// NOTE(review): the macro windows.h looks for is spelled WIN32_LEAN_AND_MEAN; the spelling
// below is kept as found so the set of declarations visible to includers does not change
#define WIN32_MEAN_AND_LEAN
#include <winsock2.h>
#include <ws2tcpip.h>
#include <windows.h>
#else
#include <netdb.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif

#ifndef WIN32
#define SOCKET int
#define INVALID_SOCKET -1
#define SOCKET_ERROR -1
#endif

namespace simplesock
{
    // Thrown when winsock cannot be initialized; only WIN32 builds throw it
    // If error() is zero, WSAStartup() succeeded but the required winsock version is not available
    // If error() is non-zero, WSAStartup() itself failed and error() is the code it returned,
    // see MSDN for an explanation of the error codes
    class init_fail : public std::exception
    {
        private:
            std::string _what;
            int _err;

        public:
            // what: human readable description, a null pointer is stored as an empty message
            //       (building a std::string from a null pointer is undefined behaviour)
            // err:  the code returned by WSAStartup(), or zero when there is none
            init_fail(const char *what, int err = 0)
                : _what(what ? what : ""),
                  _err(err)
            {
            }

            virtual ~init_fail() throw()
            {
            }

            // Returns the description given to the constructor
            // The pointer stays valid for as long as this exception object is alive
            virtual const char* what() const throw()
            {
                return _what.c_str();
            }

            // Returns the error code given to the constructor (zero if none was given)
            int error() const throw()
            {
                return _err;
            }
    };

    // A std::streambuf (the char specialization of basic_streambuf) backed by a TCP socket
    //
    // Characters written are collected in buffer_out and passed to send() when the buffer
    // fills up, on sync() and on close(). Characters read are fetched with recv() into
    // buffer_in, whose first putback_size bytes are kept free as a putback area.
    class sockbuf : public std::streambuf
    {
        protected:
            static const int buffer_size = 8192;
            static const int putback_size = 1;
            SOCKET sock;
            char buffer_out[buffer_size];
            char buffer_in[buffer_size+putback_size];

        public:
            // Creates a closed buffer
            // On WIN32 this also starts winsock and throws init_fail if WSAStartup() fails
            // or if no winsock 2.x is available
            sockbuf()
                : sock(INVALID_SOCKET)
            {
#ifdef WIN32
                // We want any winsock 2 version
                WSADATA data;
                const int err = WSAStartup(MAKEWORD(2, 0), &data);

                // If there was an error, throw init_fail carrying the WSAStartup() code
                if (err != 0)
                {
                    throw init_fail("unable to initialize winsock", err);
                }

                // Check if we got the requested major version
                if (LOBYTE(data.wVersion) != 2)
                {
                    // WSAStartup() succeeded, so it has to be balanced here: the destructor
                    // does not run for an object whose constructor throws
                    WSACleanup();
                    throw init_fail("required version not available");
                }
#endif
                reset_buffers();
            }

            // Closes the connection (see close()) and, on WIN32, releases winsock
            virtual ~sockbuf()
            {
                close();
#ifdef WIN32
                WSACleanup();
#endif
            }

            // Connects to host:port using IPv4/TCP; an already open connection is closed first
            // Only the first address returned by getaddrinfo() is tried
            // Returns this on success, NULL on failure (the buffer is left closed in that case)
            sockbuf *open(const char *host, unsigned short port)
            {
                // Close old socket first
                if (sock != INVALID_SOCKET)
                    close();

                sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
                if (sock == INVALID_SOCKET)
                {
                    return NULL;
                }

                // A value of N bytes never needs more than 3*N decimal digits
                char port_buf[3 * sizeof(unsigned short) + 1];
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable: 4996)
#endif
                std::sprintf(port_buf, "%hu", port);
#ifdef _MSC_VER
#pragma warning(pop)
#endif

                addrinfo hints;
                addrinfo *result = NULL;
                std::memset(&hints, 0, sizeof(hints));
                hints.ai_family = AF_INET;
                hints.ai_socktype = SOCK_STREAM;
                hints.ai_protocol = IPPROTO_TCP;

                if (getaddrinfo(host, port_buf, &hints, &result) != 0 || result == NULL)
                {
                    // Don't leak the socket created above
                    close_descriptor();
                    return NULL;
                }

                // Connect straight from the resolver's storage with the length it reports,
                // instead of copying the address into a bare sockaddr
                const int err = ::connect(sock, result->ai_addr, static_cast<int>(result->ai_addrlen));
                freeaddrinfo(result);

                if (err != 0)
                {
                    // The socket never got connected, so there is nothing that could be
                    // flushed to it; just release the descriptor
                    close_descriptor();
                    return NULL;
                }

                return this;
            }

            // Does not throw! Called by destructor
            // Sends output that is still buffered, closes the socket and empties both buffers,
            // so that nothing from this connection can show up on a later one
            // Returns this on success (also when already closed), NULL if sending the pending
            // output or closing the socket failed; the buffer is closed afterwards either way
            sockbuf *close()
            {
                if (sock == INVALID_SOCKET)
                {
                    return this;
                }

                const bool flushed = (sync_output() != EOF);
                const bool closed = close_descriptor();
                reset_buffers();

                return (flushed && closed) ? this : NULL;
            }

            // Returns true while a socket is held
            bool is_open() const
            {
                return (sock != INVALID_SOCKET);
            }

            // Simple resolve func, always returns first found address
            // If host starts with an alphabetic character it is looked up as a host name and
            // the dotted IPv4 address is returned; otherwise host is parsed as a dotted IPv4
            // address and the associated host name is returned
            // Returns an empty string if can't resolve
            // NOTE(review): gethostbyname()/gethostbyaddr()/inet_ntoa() hand out library owned
            // storage, so this is presumably not safe to call from several threads at once
            std::string resolve(const char *host)
            {
                if (host == NULL)
                {
                    return "";
                }

                // Are we doing normal or reverse lookup
                // (the cast keeps isalpha() well defined for chars with the high bit set)
                const bool normal_lookup = (std::isalpha(static_cast<unsigned char>(host[0])) != 0);

                if (normal_lookup)
                {
                    const hostent *result = gethostbyname(host);
                    in_addr ip;

                    if (result == NULL || result->h_addr_list == NULL || result->h_addr_list[0] == NULL
                        || result->h_length != static_cast<int>(sizeof(ip)))
                    {
                        return "";
                    }

                    // Copy exactly the address bytes; reading them through an unsigned long*
                    // would fetch 8 bytes where long is 64 bits wide
                    std::memcpy(&ip, result->h_addr_list[0], sizeof(ip));
                    return inet_ntoa(ip);
                }

                // Else let's do a reverse DNS lookup
                const unsigned int addr = inet_addr(host);
                if (addr == INADDR_NONE)
                {
                    // Not a parsable address (inet_addr() also reports 255.255.255.255 this way)
                    return "";
                }

                const hostent *result = gethostbyaddr(reinterpret_cast<const char*>(&addr),
                                                      static_cast<int>(sizeof(addr)), AF_INET);
                if (result == NULL || result->h_name == NULL)
                {
                    return "";
                }

                return result->h_name;
            }

        protected:
            // This is called when the output buffer is full, or with EOF to request a flush
            // Stores c (if any) in the slot held back by reset_buffers() and sends the buffer
            // Only output is touched here, unread input stays where it is
            virtual int_type overflow (int_type c)
            {
                if (!traits_type::eq_int_type(c, traits_type::eof()))
                {
                    // After a failed flush the spare slot is still occupied; storing another
                    // char would run past buffer_out, so the buffer has to drain first
                    if (pptr() > epptr() && sync_output() == EOF)
                    {
                        return traits_type::eof();
                    }

                    *pptr() = traits_type::to_char_type(c);
                    pbump(1);
                }

                // Flush the existing buffer
                if (sync_output() == EOF)
                {
                    return traits_type::eof();
                }

                // Success has to be reported with something other than eof, even if c was eof
                return traits_type::not_eof(c);
            }

            // This is called when input buffer needs more data
            // It first moves up to putback_size of the most recently read characters to the
            // putback area and then reads new chars from the socket
            // Returns the next character without consuming it, or eof on error / when the
            // remote side closed the connection
            virtual int_type underflow()
            {
                if (gptr() < egptr())
                {
                    return traits_type::to_int_type(*gptr());
                }

                int keep = static_cast<int>(gptr() - eback());
                if (keep > putback_size)
                {
                    keep = putback_size;
                }

                // Copy current characters to putback area
                std::memmove(buffer_in + (putback_size - keep), gptr() - keep, keep);

                // Then read the new chars
                const int num = static_cast<int>(recv(sock, buffer_in + putback_size, buffer_size, 0));

                // If error or if remote socket closed connection
                if (num <= 0)
                {
                    return traits_type::eof();
                }

                // Reset the input pointers
                setg(buffer_in + (putback_size - keep), buffer_in + putback_size, buffer_in + putback_size + num);

                // Return the next char; to_int_type() keeps bytes >= 0x80 positive, so that
                // 0xFF is not mistaken for eof
                return traits_type::to_int_type(*gptr());
            }

            // Steps the get pointer back by one if there is room and c is either eof or equal
            // to the character in front of the get pointer
            // Returns the character now at the get pointer, or eof if nothing was put back
            virtual int_type pbackfail(int_type c)
            {
                if (gptr() > eback())
                {
                    const bool any_char = traits_type::eq_int_type(c, traits_type::eof());
                    if (any_char || traits_type::eq_int_type(c, traits_type::to_int_type(gptr()[-1])))
                    {
                        gbump(-1);
                        return traits_type::to_int_type(*gptr());
                    }
                }

                // Otherwise return eof
                return traits_type::eof();
            }

            // Syncs output buffer: sends everything between pbase() and pptr()
            // Returns the number of characters sent, or EOF on failure; after a failure the
            // characters that did not go out are kept at the start of the buffer
            int sync_output()
            {
                const int pending = static_cast<int>(pptr() - pbase());
                int sent = 0;

                // send() may accept only part of the data, so keep going until all is out
                while (sent < pending)
                {
                    const int n = static_cast<int>(send(sock, pbase() + sent, pending - sent, 0));
                    if (n <= 0)
                    {
                        // Drop what already went out so that a retry doesn't send it twice
                        if (sent > 0)
                        {
                            std::memmove(pbase(), pbase() + sent, pending - sent);
                            pbump(-sent);
                        }
                        return EOF;
                    }
                    sent += n;
                }

                pbump(-pending);
                return pending;
            }

            // Syncs input buffer: discards the characters that were received but not read yet
            // Returns the number of discarded characters
            int sync_input()
            {
                const int discard = static_cast<int>(egptr() - gptr());

                // If the get pointer is at end, the stream is already synchronized
                if (discard == 0)
                {
                    return 0;
                }

                // Otherwise let's discard the characters
                int keep = static_cast<int>(gptr() - eback());
                if (keep > putback_size)
                {
                    keep = putback_size;
                }

                // Copy current characters to putback area
                std::memmove(buffer_in + (putback_size - keep), gptr() - keep, keep);

                // Set new input pointers
                setg(buffer_in + (putback_size - keep), buffer_in + putback_size, buffer_in + putback_size);

                return discard;
            }

            // Flushes the output buffer and discards unread input
            // Returns 0 on success, -1 if the output could not be sent
            // NOTE(review): stream level flushes (std::flush, std::endl) end up here, so they
            // also drop input that was received but not read yet - confirm this is intended
            virtual int sync()
            {
                if (sync_output() == EOF || sync_input() == EOF)
                {
                    return -1;
                }
                return 0;
            }

        private:
            // Not copyable: a copy would share the socket (closing it twice) and its stream
            // pointers would refer to the buffers of the object it was copied from
            sockbuf(const sockbuf&);
            sockbuf &operator=(const sockbuf&);

            // Puts both buffers into their initial, empty state
            void reset_buffers()
            {
                // We decrement the size of buffer by one, so we can insert one more
                // char when overflow() is called, this makes it easier to handle
                setp(buffer_out, buffer_out + sizeof(buffer_out) - 1);
                // Empty get area placed right behind the putback area
                setg(buffer_in + putback_size, buffer_in + putback_size, buffer_in + putback_size);
            }

            // Releases the socket without touching the buffers
            // Returns false if the system call reported an error; sock is invalid afterwards
            bool close_descriptor()
            {
                if (sock == INVALID_SOCKET)
                {
                    return true;
                }

#ifdef WIN32
                const int err = closesocket(sock);
#else
                const int err = ::close(sock);
#endif
                sock = INVALID_SOCKET;
                return (err != SOCKET_ERROR);
            }
    };

    class sockstream : public std::iostream
    {
        protected:
            sockbuf buf;

        public:
            sockstream() : std::iostream(&buf)
            {
            }

            sockstream(const char *host, unsigned short port) : std::iostream(&buf)
            {
                open(host, port);
            }

            sockstream(const char *host, const char *port) : std::iostream(&buf)
            {
                open(host, port);
            }

            virtual ~sockstream()
            {
            }

            sockbuf *rdbuf() const
            {
                return (sockbuf*)&buf;
            }

            void open(const char *host, unsigned short port)
            {
                if (buf.open(host, port) == 0)
                {
                    setstate(std::ios_base::failbit);
                }
            }

            void open(const char *host, const char *port)
            {
                open(host, std::atoi(port));
            }

            void close()
            {
                if (buf.close() == 0)
                {
                    setstate(std::ios_base::failbit);
                }
            }

            bool is_open()
            {
                return buf.is_open();
            }
    };
}

#endif

