/**
 * Measure data transferred by a process.
 *
 * Used with LD_PRELOAD.
 */

#define _GNU_SOURCE
#include <dlfcn.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>

/* Running byte totals, updated by the wrappers below: total_sent by
 * send(), sendmmsg() and write(); total_recv by recv() and read(). */
static unsigned long long total_sent = 0;
static unsigned long long total_recv = 0;
/* Wall-clock time recorded by transfcount_init(); transfcount_atexit()
 * subtracts it from the exit time to compute the transfer rates. */
static time_t start_time;

/* Internal helpers, defined at the end of this file. */
static void transfcount_init(void);
static void transfcount_atexit(void);
static void transfcount_print_bytes(unsigned long long n);

/* Function-pointer types with the same signatures as the wrapped
 * functions; used to cast the results of dlsym(). */
typedef ssize_t (*send_type)(int sockfd, const void *buf, size_t len,
		int flags);
typedef int (*sendmmsg_type)(int sockfd, struct mmsghdr *msgvec,
		unsigned int vlen, int flags);
typedef ssize_t (*recv_type)(int sockfd, void *buf, size_t len, int flags);
typedef ssize_t (*read_type)(int fd, void *buf, size_t count);
typedef ssize_t (*write_type)(int fd, const void *buf, size_t count);

/* Pointers to the next definitions of the wrapped functions. They are
 * filled in by transfcount_init() via dlsym(RTLD_NEXT, ...), so every
 * wrapper calls transfcount_init() before using them. */
static send_type orig_send;
static sendmmsg_type orig_sendmmsg;
static recv_type orig_recv;
static read_type orig_read;
static write_type orig_write;


/**
 * send - send a message on a socket.
 *
 * Wrapper that forwards the call unchanged to the next definition of
 * send() and adds the number of bytes it reports as sent to total_sent.
 *
 * Returns the value returned by the underlying send().
 */
ssize_t send(int sockfd, const void *buf, size_t len, int flags)
{
	ssize_t bytes;

	transfcount_init();

	bytes = orig_send(sockfd, buf, len, flags);

	/* A failed call returns -1; adding that to the unsigned counter
	 * would wrap it around, so only count successful transfers. */
	if (bytes > 0)
		total_sent += (unsigned long long)bytes;

	return bytes;
}


/**
 * sendmmsg - send multiple messages on a socket.
 *
 * Wrapper that forwards the call to the next definition of sendmmsg()
 * and then adds the msg_len field of each message reported as sent to
 * total_sent.
 *
 * Returns the value returned by the underlying sendmmsg().
 */
int sendmmsg(int sockfd, struct mmsghdr *msgvec, unsigned int vlen,
		int flags)
{
	const struct mmsghdr *msg;
	int count;
	int remaining;

	transfcount_init();

	count = orig_sendmmsg(sockfd, msgvec, vlen, flags);

	/* Walk the first `count` entries; a non-positive count (error or
	 * nothing sent) skips the loop entirely. */
	msg = msgvec;
	remaining = count;
	while (remaining > 0) {
		total_sent += msg->msg_len;
		msg++;
		remaining--;
	}

	return count;
}


/**
 * recv - receive a message from a socket.
 *
 * Wrapper that forwards the call unchanged to the next definition of
 * recv() and adds the number of bytes it reports as received to
 * total_recv.
 *
 * Returns the value returned by the underlying recv().
 */
ssize_t recv(int sockfd, void *buf, size_t len, int flags)
{
	ssize_t bytes;

	transfcount_init();

	bytes = orig_recv(sockfd, buf, len, flags);

	/* A failed call returns -1; adding that to the unsigned counter
	 * would wrap it around, so only count successful transfers. */
	if (bytes > 0)
		total_recv += (unsigned long long)bytes;

	return bytes;
}


/**
 * read - read from a file descriptor.
 *
 * Wrapper that forwards the call to the next definition of read(). The
 * bytes read are added to total_recv only when fstat() reports that fd
 * is a socket; reads from any other kind of descriptor are passed
 * through uncounted.
 *
 * Returns the value returned by the underlying read().
 */
ssize_t read(int fd, void *buf, size_t count)
{
	struct stat statbuf;
	ssize_t bytes;
	bool is_socket;

	transfcount_init();

	/* If fstat() fails, statbuf is left unset and must not be inspected;
	 * treat the descriptor as a non-socket in that case. */
	is_socket = fstat(fd, &statbuf) == 0 && S_ISSOCK(statbuf.st_mode);

	bytes = orig_read(fd, buf, count);

	/* Skip -1 (error) so the unsigned counter cannot wrap. */
	if (is_socket && bytes > 0)
		total_recv += (unsigned long long)bytes;

	return bytes;
}


/**
 * write - write to a file descriptor.
 *
 * Wrapper that forwards the call to the next definition of write(). The
 * bytes written are added to total_sent only when fstat() reports that
 * fd is a socket; writes to any other kind of descriptor are passed
 * through uncounted.
 *
 * Returns the value returned by the underlying write().
 */
ssize_t write(int fd, const void *buf, size_t count)
{
	struct stat statbuf;
	ssize_t bytes;
	bool is_socket;

	transfcount_init();

	/* If fstat() fails, statbuf is left unset and must not be inspected;
	 * treat the descriptor as a non-socket in that case. */
	is_socket = fstat(fd, &statbuf) == 0 && S_ISSOCK(statbuf.st_mode);

	bytes = orig_write(fd, buf, count);

	/* Skip -1 (error) so the unsigned counter cannot wrap. */
	if (is_socket && bytes > 0)
		total_sent += (unsigned long long)bytes;

	return bytes;
}


/**
 * One-time library setup.
 *
 * On the first call this looks up the next definitions of the wrapped
 * functions with dlsym(RTLD_NEXT, ...), registers transfcount_atexit()
 * with atexit() and records the start time. Every later call returns
 * without doing anything.
 */
void transfcount_init(void)
{
	static bool ready = false;

	if (!ready) {
		/* Resolve the functions that the wrappers forward to. */
		orig_send = (send_type)dlsym(RTLD_NEXT, "send");
		orig_sendmmsg = (sendmmsg_type)dlsym(RTLD_NEXT, "sendmmsg");
		orig_recv = (recv_type)dlsym(RTLD_NEXT, "recv");
		orig_read = (read_type)dlsym(RTLD_NEXT, "read");
		orig_write = (write_type)dlsym(RTLD_NEXT, "write");

		/* Print the statistics when the process exits. */
		atexit(transfcount_atexit);

		start_time = time(NULL);

		ready = true;
	}
}


/**
 * Output transfer statistics.
 *
 * Prints the total bytes sent and received, and the average rate of each
 * since start_time, to stderr on a single line. Prints nothing when both
 * counters are zero.
 */
void transfcount_atexit(void)
{
	time_t duration = time(NULL) - start_time;
	unsigned long long seconds;

	if (total_sent == 0 && total_recv == 0)
		return;

	/* Use at least one second as the divisor: this avoids a division by
	 * zero for short runs, and keeps a negative duration (wall clock
	 * stepped backwards) from becoming a huge unsigned divisor. */
	seconds = duration > 0 ? (unsigned long long)duration : 1;

	fprintf(stderr, "[>>> Sent ");
	transfcount_print_bytes(total_sent);
	fprintf(stderr, "B (");
	transfcount_print_bytes(total_sent / seconds);
	fprintf(stderr, "B/s) >>>] [<<< Received ");
	transfcount_print_bytes(total_recv);
	fprintf(stderr, "B (");
	transfcount_print_bytes(total_recv / seconds);
	fprintf(stderr, "B/s) <<<]\n");
}


/**
 * Print a size as human readable.
 *
 * Writes n to stderr with no trailing unit letter or newline: values
 * below 1024 are printed as a plain integer, larger values with one
 * decimal place and a binary prefix ("Ki", "Mi" or "Gi").
 */
void transfcount_print_bytes(unsigned long long n)
{
	const unsigned long long kib = 1024ULL;
	const unsigned long long mib = 1024ULL * 1024;
	const unsigned long long gib = 1024ULL * 1024 * 1024;

	if (n < kib) {
		fprintf(stderr, "%llu", n);
		return;
	}

	if (n < mib) {
		fprintf(stderr, "%.1fKi", (double) n / 1024);
		return;
	}

	if (n < gib) {
		fprintf(stderr, "%.1fMi", (double) n / 1024 / 1024);
		return;
	}

	fprintf(stderr, "%.1fGi", (double) n / 1024 / 1024 / 1024);
}

