diff --git a/src/lxc/lxc_init.c b/src/lxc/lxc_init.c index f2011f9fb..bcdfcc389 100644 --- a/src/lxc/lxc_init.c +++ b/src/lxc/lxc_init.c @@ -92,6 +92,86 @@ static struct arguments my_args = { .shortopts = short_options }; +static void prevent_forking(void) +{ + FILE *f; + char name[PATH_MAX], path[PATH_MAX]; + int ret; + + f = fopen("/proc/self/cgroup", "r"); + if (!f) { + SYSERROR("opening /proc/self/cgroup"); + return; + } + + while (!feof(f)) { + int fd; + + if (2 != fscanf(f, "%*d:%[^:]:%s", name, path)) { + ERROR("didn't scan the right number of things"); + goto out; + } + + if (strcmp(name, "pids")) + continue; + + ret = snprintf(name, sizeof(name), "/sys/fs/cgroup/pids/%s/pids.max", path); + if (ret < 0 || ret >= sizeof(path)) { + ERROR("failed snprintf"); + goto out; + } + + fd = open(name, O_WRONLY); + if (fd < 0) { + SYSERROR("open"); + goto out; + } + + if (write(fd, "1", 1) != 1) + SYSERROR("write"); + + close(fd); + break; + } + +out: + fclose(f); +} + +static void kill_children(pid_t pid) +{ + FILE *f; + char path[PATH_MAX]; + int ret; + + ret = snprintf(path, sizeof(path), "/proc/%d/task/%d/children", pid, pid); + if (ret < 0 || ret >= sizeof(path)) { + ERROR("failed snprintf"); + return; + } + + f = fopen(path, "r"); + if (!f) { + SYSERROR("couldn't open %s", path); + return; + } + + while (!feof(f)) { + pid_t pid; + + if (fscanf(f, "%d ", &pid) != 1) { + ERROR("couldn't scan pid"); + fclose(f); + return; + } + + kill_children(pid); + kill(pid, SIGKILL); + } + + fclose(f); +} + int main(int argc, char *argv[]) { int i, ret; @@ -258,18 +338,28 @@ int main(int argc, char *argv[]) case SIGTERM: if (!shutdown) { shutdown = 1; - ret = kill(-1, SIGTERM); - if (ret < 0) - DEBUG("%s - Failed to send SIGTERM to " - "all children", strerror(errno)); + prevent_forking(); + if (getpid() != 1) { + kill_children(getpid()); + } else { + ret = kill(-1, SIGTERM); + if (ret < 0) + DEBUG("%s - Failed to send SIGTERM to " + "all children", strerror(errno)); + } alarm(1); } break; case SIGALRM: - ret = kill(-1, SIGKILL); - if (ret < 0) - DEBUG("%s - Failed to send SIGKILL to all " - "children", strerror(errno)); + prevent_forking(); + if (getpid() != 1) { + kill_children(getpid()); + } else { + ret = kill(-1, SIGTERM); + if (ret < 0) + DEBUG("%s - Failed to send SIGTERM to " + "all children", strerror(errno)); + } break; default: ret = kill(pid, was_interrupted);