src/: +streamfer*
[nethome.git] / src / streamfer-server.C
diff --git a/src/streamfer-server.C b/src/streamfer-server.C
new file mode 100644 (file)
index 0000000..dedb9f3
--- /dev/null
@@ -0,0 +1,207 @@
+#include "safeio.h"
+#include "socket.h"
+#include "stringf.h"
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <dirent.h>
+#include <poll.h>
+#include <climits>
+#include <cstdlib>
+#include <csignal>
+
+// https://stackoverflow.com/a/8615450/2995591
+#include <glob.h> // glob(),globfree()
+#include <cstring> // memset()
+#include <vector>
+#include <string>
+
+static std::vector<std::string> cxxglob(const std::string pattern) {
+  glob_t glob_result;
+  memset(&glob_result,0,sizeof(glob_result));
+  int return_value=glob(pattern.c_str(),GLOB_TILDE,NULL,&glob_result);
+  if (return_value)
+    fatal("glob() failed with return_value %s",return_value);
+  vector<string> filenames;
+  filenames.reserve(glob_result.gl_pathc);
+  for (size_t i = 0; i < glob_result.gl_pathc; ++i)
+     filenames.push_back(string(glob_result.gl_pathv[i]));
+  globfree(&glob_result);
+  return filenames;
+}
+
+// FIXME: Use C++17
+static bool fd_is_open(const char *execname,const char *fn) {
+  const char slashproc[]("/proc");
+  DIR *dir(opendir(slashproc));
+  if (!dir)
+    fatal("Cannot opendir %s: %m",slashproc);
+  bool retval(false);
+  for (;;) {
+    errno=0;
+    const struct dirent *de=readdir(dir);
+    if (!de) {
+      if (errno)
+       fatal("Cannot readdir %s: %m",slashproc);
+      break;
+    }
+    if (!isdigit(de->d_name[0]))
+      continue;
+
+    char buf[PATH_MAX];
+    ssize_t got(readlinkat(dirfd(dir),stringf("%s/exe",de->d_name).c_str(),buf,sizeof(buf)));
+    if (got==-1||got==sizeof(buf))
+      continue;
+    buf[got]=0;
+    char *s=strrchr(buf,'/');
+    if (!s)
+      continue;
+    if (strcmp(s+1,execname)!=0)
+      continue;
+
+    string procpidfd(stringf("/proc/%s/fd",de->d_name));
+    DIR *fddir(opendir(procpidfd.c_str()));
+    if (!fddir)
+      fatal("Cannot opendir %s: %m",procpidfd.c_str());
+    for (;;) {
+      errno=0;
+      const struct dirent *de=readdir(fddir);
+      if (!de) {
+       if (errno)
+         fatal("Cannot readdir %s: %m",procpidfd.c_str());
+       break;
+      }
+      if (!isdigit(de->d_name[0]))
+       continue;
+      char buf[PATH_MAX];
+      ssize_t got(readlinkat(dirfd(fddir),de->d_name,buf,sizeof(buf)));
+      if (got==-1||got==sizeof(buf))
+       continue;
+      buf[got]=0;
+      if (strcmp(buf,fn)==0) {
+       retval=true;
+       break;
+      }
+    }
+    if (closedir(fddir))
+      fatal("Cannot closedir %s: %m",procpidfd.c_str());
+    if (retval)
+      break;
+  }
+  if (closedir(dir))
+    fatal("Cannot closedir %s: %m",slashproc);
+  return retval;
+}
+
+int main(int argc,char **argv) {
+  static struct sigaction sigchld;
+  sigchld.sa_handler=SIG_DFL;
+  sigchld.sa_flags=SA_NOCLDWAIT;
+  int err(sigaction(SIGCHLD,&sigchld,nullptr));
+  assert(!err);
+
+  if (argc!=1+2&&argc!=1+3)
+    fatal("streamfer-server [<listen-host>:]<listen-port> <prefix> [follow-fd-of-executable-basename]");
+  string prefix;
+  if (argc>=1+2&&*argv[2])
+    prefix=argv[2];
+  const char *execname(nullptr);
+  if (argc>=1+3)
+    execname=argv[3];
+  int listen_fd(socket_bind(argv[1]));
+  int client_fd;
+  for (;;) {
+    client_fd=socket_accept(listen_fd,[&](int client_fd,string addr) {
+      warning("%d:%s",client_fd,addr.c_str());
+    });
+    int child(fork());
+    assert(child!=-1);
+    if (!child)
+      break;
+    int err(close(client_fd));
+    assert(!err);
+  }
+  err=close(listen_fd);
+  assert(!err);
+
+  string pattern(read_safe_string(client_fd));
+  std::vector<std::string> matched(cxxglob(pattern));
+  for (size_t ix=0;ix<matched.size()-1;++ix) {
+    const std::string &a(matched[ix  ]);
+    const std::string &b(matched[ix+1]);
+    int err(strcmp(a.c_str(),b.c_str()));
+    if (err>=0)
+      fatal("glob: strcmp(\"%s\",\"%s\")=%d",a.c_str(),b.c_str(),err);
+  }
+  string last(read_safe_string(client_fd));
+  size_t lastix(SIZE_MAX);
+  for (size_t ix=0;ix<matched.size();++ix) {
+    const std::string &member(matched[ix]);
+    if (strcmp(last.c_str(),member.c_str())>0)
+      assert(lastix==SIZE_MAX);
+    else if (lastix==SIZE_MAX)
+      lastix=ix;
+  }
+  if (lastix==SIZE_MAX)
+    fatal("Requested too new file");
+  uint64_t offset;
+  read_safe(client_fd,offset);
+  const string *fnp;
+  int file_fd=-1;
+  struct stat statbuf;
+  for (;lastix<matched.size();file_fd=-1,++lastix) {
+    fnp=&matched[lastix];
+    const string &fn(*fnp);
+    file_fd=open(fn.c_str(),O_RDONLY);
+    if (file_fd==-1) {
+      if (errno!=ENOENT)
+       fatal("Cannot open %s: %m",fn.c_str());
+      continue;
+    }
+    int err(fstat(file_fd,&statbuf));
+    assert(!err);
+    if (offset<(uint64_t)statbuf.st_size)
+      break;
+    if (offset>(uint64_t)statbuf.st_size)
+      warning("File %s has transferred %zu < %zu which is its size",fn.c_str(),(size_t)offset,(size_t)statbuf.st_size);
+    if (lastix==matched.size()-1&&execname)
+      break;
+    err=close(file_fd);
+    assert(!err);
+    offset=0;
+  }
+  if (file_fd==-1) {
+    string empty("");
+    write_safe(client_fd,empty);
+    fatal("No more files to transfer");
+  }
+  const string &fn(*fnp);
+  const char *fn_canon(nullptr);
+  if (!prefix.empty()||execname) {
+    fn_canon=realpath(fn.c_str(),nullptr);
+    if (!fn_canon)
+      fatal("realpath %s: %m",fn.c_str());
+  }
+  if (fn!=last)
+    offset=0;
+  if (!prefix.empty()&&strncmp(prefix.c_str(),fn_canon,prefix.length())!=0)
+    fatal("prefix=\"%s\" realpath=\"%s\"",prefix.c_str(),fn_canon);
+  warning("%s @%zu",fn.c_str(),(size_t)offset);
+  write_safe(client_fd,fn);
+  write_safe(client_fd,statbuf.st_mtim);
+  off_t got(lseek(file_fd,offset,SEEK_SET));
+  assert((uint64_t)got==offset);
+  struct pollfd fds;
+  fds.fd=client_fd;
+  fds.events=POLLIN|POLLPRI|POLLRDHUP;
+  for (;;) {
+    transfer(file_fd,fn.c_str(),client_fd,"client fd");
+    if (!fn_canon||!fd_is_open(execname,fn_canon))
+      break;
+    int err(poll(&fds,1,1000/*ms*/));
+    if (err==-1)
+      fatal("poll client fd: %m");
+    if (err==1) 
+      fatal("poll client fd: revents=0x%x",fds.revents);
+    assert(err==0);
+  }
+}