#include <cstdlib>
#include <iostream>
#include <fstream>
#include <sstream>
#include <thread>
#include <stdexcept>
#include <system_error>
#include <string>
#include <list>
#include <random>
#include <chrono>
#include <mutex>
#include <condition_variable>
#include <map>
using std::string;
#include <cleansocks.h>
#include <cleanip.h>
#include <cleanbuf.h>
using namespace cleansocks;
#include "sha256.h"
// Default port for reader connections; main() puts senders on the next
// port up unless the command line says otherwise.
constexpr int DFLTPORT = 45100;
// Password table: user name -> password, filled in by loadpwd().
std::map<string, string> pwds;
// Usage widget.
// Program name shown in the synopsis; main() sets it from argv[0].
string pgm;
// Print msg (when non-empty) and then the command synopsis to stderr.
// Unless done is false, the process then exits with status 1.
void usage(string msg = "", bool done = true)
{
	const bool have_msg = !msg.empty();
	if(have_msg)
		std::cerr << msg << std::endl;
	std::cerr << "Usage: " << pgm << " dbase [ pout [ pin ] ]" << std::endl;
	if(!done)
		return;
	exit(1);
}
// Load the passwords.
//
// Reads fn into pwds, one "user:password" entry per line. The user name is
// everything before the first ':' and the password everything after it.
// Blank lines and lines starting with '#' are ignored. A line with no ':'
// is reported on stderr and skipped, rather than being stored under a
// bogus user name. A trailing '\r' is removed so that files with CRLF line
// endings do not put '\r' into every password.
// Calls usage(), which exits, if the file cannot be opened.
void loadpwd(string fn)
{
	std::ifstream in(fn);
	if(!in) usage("Cannot read " + fn);
	string ln;
	int lineno = 0;
	while(std::getline(in, ln)) {
		++lineno;
		if(!ln.empty() && ln.back() == '\r') ln.pop_back();
		if(ln.empty() || ln[0] == '#') continue;
		// Keep the full size_type: find() returns npos when ':' is absent.
		const string::size_type splitpt = ln.find(':');
		if(splitpt == string::npos) {
			std::cerr << fn << ":" << lineno
				  << ": no ':' separator, line ignored" << std::endl;
			continue;
		}
		pwds[ln.substr(0, splitpt)] = ln.substr(splitpt + 1);
	}
}
// One connected reader socket, plus a count of sender threads that are
// currently using it. Callers in this file call get()/put() only while
// holding listlock. handle_sender() erases a record only when put()
// reports that no users remain.
class listrecd {
public:
	listrecd(TCPsocket s) : sock(s) { }

	// Register one more user of this record and return its socket.
	TCPsocket get()
	{
		users += 1;
		return sock;
	}

	// Release one use and report how many users are left.
	int put()
	{
		users -= 1;
		return users;
	}

private:
	TCPsocket sock;
	int users = 0;
};

// Active reader connections; guarded by listlock.
std::list<listrecd> listeners;
std::mutex listlock;
// Listen for reader connections.
//
// Binds a listening socket on port msgout (any local address) and accepts
// reader clients forever. Each accepted socket is appended to `listeners`
// under listlock, so that sender threads can forward messages to it.
// Readers are not authenticated. Returns normally only if bind/listen
// fails; accept errors are logged and retried after a short pause.
void reader_listen(int msgout)
{
	TCPsocket listener;
	try {
		bind(listener, IPendpoint(IPaddress::any(), msgout));
		listen(listener);
	} catch(socket_error &e) {
		std::cout << "Error on listener bind/listen: " << e.what()
			  << std::endl;
		close(listener);
		return;
	}
	while(true) {
		// Wait for a reader to connect on the specified port, and
		// report each one that is accepted.
		IPendpoint rmt;
		TCPsocket c;
		try {
			c = accept(listener, rmt);
		} catch(socket_error &e) {
			std::cout << "Error on listener accept: " << e.what()
				  << std::endl;
			// Pause so that a persistent failure (for example,
			// running out of descriptors) cannot spin this loop.
			std::this_thread::sleep_for(std::chrono::milliseconds(100));
			continue;
		}
		std::cout << "Accepted receiving client " << rmt << std::endl;
		{
			std::lock_guard<std::mutex> mx(listlock);
			listeners.emplace_back(c);
		}
		std::cout << "Client " << rmt << " added to listener list."
			  << std::endl;
	}
}
/*
 * Generate a nonce for sending to a client.
 *
 * The nonce is NONSIZE random bytes, written as 2*NONSIZE lowercase hex
 * digits. The bytes are drawn as 16-bit chunks, each formatted as "%04x"
 * would format it.
 */
static const int NONSIZE = 16;
// Fallback source, used only when std::random_device cannot be used.
// Because it is seeded from the clock, its output is predictable.
static std::default_random_engine generator
	(std::chrono::system_clock::now().time_since_epoch().count());
static std::uniform_int_distribution<int> distribution(0,0xffff);
// main() runs handle_sender() in a separate detached thread for each
// sender, and each of those threads calls generate_nonce(). The engine and
// distribution above must not be used by several threads at once, so every
// use is serialized by this mutex.
static std::mutex nonce_lock;

// Append NONSIZE bytes drawn from src to out, as lowercase hex digits.
template<class Source>
static void append_hex_nonce(std::string &out, Source &src)
{
	static const char hexdig[] = "0123456789abcdef";
	for(int i = 0; i < NONSIZE; i += 2) {
		const int chunk = distribution(src);	// 16 random bits
		for(int shift = 12; shift >= 0; shift -= 4)
			out += hexdig[(chunk >> shift) & 0xf];
	}
}

std::string generate_nonce()
{
	std::lock_guard<std::mutex> lck(nonce_lock);
	std::string ret;
	ret.reserve(2 * NONSIZE);
	try {
		// Prefer the platform's entropy source. A nonce drawn from the
		// clock-seeded engine can be predicted from earlier nonces,
		// which would let a prepared response be replayed. (The
		// standard permits a deterministic random_device; mainstream
		// platforms provide a non-deterministic one.)
		std::random_device rdev;
		append_hex_nonce(ret, rdev);
	} catch(const std::exception &) {
		// random_device is unavailable or failed mid-read; start over
		// with the fallback engine.
		ret.clear();
		append_hex_nonce(ret, generator);
	}
	return ret;
}
/*
 * Handle one sender connection. main() runs this in a detached thread for
 * each accepted sender.
 *
 * Authorization: sends a nonce line and reads one line of the form
 * "user hash". The client is accepted only if hash equals
 * sha256(nonce + password).getx() for that user's entry in pwds. getx() is
 * presumably a hex digest; confirm in sha256.h. On failure the function
 * replies "-Unknown user <user>" or "-Bad password", closes the socket and
 * returns. On success it replies "+OK".
 *
 * Relay: each later line must end in "\r\n". It is prefixed with
 * "user: " and sent to every reader in `listeners`. If the send to a reader
 * fails, the reader is erased from the list and closed, but only by the
 * last thread that was using its record. Lines whose first "\r\n" is not
 * at the very end are logged and dropped.
 *
 * Returns when recvln() returns 0 (presumably the client closed) or after
 * a socket_error on the client connection. The client socket is closed
 * on all of these paths.
 *
 * _c  - the accepted sender socket (wrapped in a buffered_socket).
 * cli - the sender's remote endpoint, used only in log messages.
 */
void handle_sender(TCPsocket _c, const IPendpoint cli)
{
	buffered_socket c(_c);
	try {
		// Send a nonce to the prospective client.
		string nonce = generate_nonce();
		send(c, nonce + "\r\n");
		// Read an authorization response.
		char buf[1024];
		int ct = recvln(c, buf, sizeof buf);
		// Split at whitespace into user name and hash. A field the
		// client did not send is left empty.
		string user, senthash;
		std::stringstream(string(buf,ct)) >> user >> senthash;
		// Attempt to verify.
		auto pwe = pwds.find(user);
		if(pwe == pwds.end()) {
			send(c, string("-Unknown user ")+user+"\r\n");
			close(c);
			std::cout << "Unknown user " << cli << std::endl;
			return;
		}
		// The expected response hashes the nonce followed by the
		// stored password.
		sha256 truehash(nonce+pwe->second);
		if(truehash.getx() != senthash) {
			send(c, string("-Bad password \r\n"));
			close(c);
			std::cout << "Bad password " << user << " "
				  << cli << std::endl;
			return;
		}
		send(c,string("+OK\r\n"));
		std::cout << "Authorized " << cli << std::endl;
		// Read and relay messages.
		while(true) {
			// Get the message. Check termination.
			// (This ct shadows the one used for the
			// authorization line above.)
			int ct = recvln(c, buf, sizeof buf);
			if(ct == 0) {
				std::cout << "Client " << user << " " << cli
					  << " terminated." << std::endl;
				break;
			}
			string msg(buf,ct);
			// Accept only lines whose first CRLF is their last two
			// bytes. NOTE(review): what recvln() returns for a
			// line longer than buf is not visible here; such a
			// line presumably fails this check. Confirm.
			if(ct < 2 || msg.find("\r\n") != ct - 2) {
				std::cout << "Badly terminated message "
					  << user << " " << cli << std::endl;
				continue;
			}
			msg = user + ": " + msg;
			/*
			 * Forward to the listeners.
			 *
			 * Records are walked with an iterator that is used
			 * outside the lock. This is safe because
			 * std::list::erase invalidates only the erased
			 * element, and a record is erased only when put()
			 * shows that no thread holds it.
			 */
			// Use a critical section to get the first listener
			// from the list, and record if one exists. If there
			// is one, this also runs get() to notify other threads
			// that the record is in use.
			bool done = true;
			TCPsocket lstner;
			std::list<listrecd>::iterator scan;
			{
				// Use CS to initialize the scanner, see if
				// there's anything in the list to enter the
				// loop for, and (if so), get the socket to
				// send to. The get method also updates the
				// refct on the listener record so only the
				// last thread to want to delete it can
				// actually remove it from the list.
				std::lock_guard lck(listlock);
				scan = listeners.begin();
				done = (scan == listeners.end());
				if(!done) lstner = scan->get();
			}
			// This goes through the list to handle each listener.
			// It starts if the earlier CS found and got a record.
			while(!done) {
				// Send a message. The value n indicates
				// success or error; a thrown socket_error
				// counts as failure.
				// NOTE(review): this send runs outside
				// listlock, so two sender threads can write
				// to the same listener at once. Whether
				// cleansocks' send() keeps each call's bytes
				// together is not visible here. Confirm.
				int n = -1;
				try {
					n = send(lstner, msg);
				} catch(socket_error &e) {
					n = -1;
				}
				// This CS gets the next listener (if any),
				// recording if there is one. Closes the
				// socket if the send failed and no other
				// thread is using it.
				int oldct;
				TCPsocket oldlstnr;
				{
					// In the CS, put the listener record
					// back since we're done. If it
					// failed on the send, and is not in
					// use elsewhere, remove it. If another
					// thread still holds it (oldct > 0),
					// it stays in the list and a later
					// failed send may remove it.
					std::lock_guard<std::mutex> lck(listlock);
					auto oldscan = scan++;
					oldct = oldscan->put();
					if(n <= 0 && oldct == 0) {
						listeners.erase(oldscan);
					}
					done = (scan == listeners.end());
					oldlstnr = lstner;
					if(!done) lstner = scan->get();
				}
				// Close the socket if we removed it. This is
				// done after the lock is released: the record
				// is gone, so no other thread can reach it.
				if(n <= 0 && oldct == 0) {
					close(oldlstnr);
					std::cout << "Listener dropped"
						  << std::endl;
				}
			}
		}
	} catch(socket_error &boom) {
		std::cout << "Connection to " << cli << " broken:"
			  << boom.what() << std::endl;
	}
	close(c);
}
int main(int argc, char **argv)
{
pgm = argv[0];
++argv; --argc;
// Port numbers
int msgout = DFLTPORT, msgin = DFLTPORT+1;
if(argc < 1 || argc > 3) usage();
loadpwd(argv[0]);
try {
if(argc > 1) {
msgout = std::stoi(argv[1]);
if(argc > 2) {
msgin = std::stoi(argv[2]);
} else {
msgin = msgout + 1;
}
}
} catch(std::invalid_argument e) {
usage();
}
// The thread to catch reader connections.
std::thread rdrs(reader_listen, msgout);
// Catch sender connections. Thread authorizes, then processes
// any messages sent until the connection closes.
TCPsocket listener;
try {
bind(listener, IPendpoint(IPaddress::any(), msgin));
listen(listener);
} catch(socket_error &e) {
std::cout << "Error binding listener port: " << e.what()
<< std::endl;
}
while(true) {
try {
// Wait for contact from a client on specified port.
// Also prints a message when clients are accpeted.
IPendpoint rmt;
TCPsocket c = accept(listener, rmt);
std::cout << "Accepted sending client " << rmt
<< std::endl;
std::thread sndthd(handle_sender, c, rmt);
sndthd.detach();
std::cout << "Client thread started "
<< rmt << std::endl;
} catch(socket_error &e) {
std::cout << "Error attempting to accept sender: "
<< e.what() << std::endl;
}
}
rdrs.join();
}