package main
import (
"context"
"errors"
"fmt"
"github.com/fsnotify/fsnotify"
"github.com/jochenvg/go-udev"
"io/ioutil"
"os"
"os/exec"
"os/signal"
"regexp"
"strings"
"syscall"
"time"
)
// Logger is the small logging interface that most functions in this file
// return next to their error, so the caller can report the error under the
// name of the function that produced it.
type Logger interface {
	// Log writes an informational message.
	Log(...interface{})
	// LogErr writes an error message.
	LogErr(...interface{})
}

// FuncLogger is a Logger that prefixes each message with funcName.
type FuncLogger struct {
	funcName string // prefix written before every message
}

// Log writes "funcName: " followed by args, formatted as by fmt.Println, to
// standard output.
//
// The prefix and the message are emitted with a single write so that lines
// logged concurrently from several goroutines cannot interleave mid-line.
func (l FuncLogger) Log(args ...interface{}) {
	fmt.Fprint(os.Stdout, l.funcName+": "+fmt.Sprintln(args...))
}

// LogErr is like Log but writes to standard error.
func (l FuncLogger) LogErr(args ...interface{}) {
	fmt.Fprint(os.Stderr, l.funcName+": "+fmt.Sprintln(args...))
}
// ConfigLine is one logical line of a rule file, produced by merging
// backslash-continued physical lines and dropping comments and blank lines.
type ConfigLine struct {
	fst, last int    // 1-based numbers of the first and last physical lines
	line      string // text of the logical line
}
// MyPathError reports a problem with a path that the standard library does
// not already describe with its own error.
type MyPathError struct {
	id   string // kind of failure; "NOTREGULAR" is the one this file constructs
	file string // path the failure refers to
	fun  string // name of the operation that detected the failure
}

// Error implements the error interface.
//
// An id without a dedicated message yields a generic description instead of
// panicking: Error is usually called while an error is being reported, and a
// panic there would replace the original problem with a crash.
func (e *MyPathError) Error() string {
	switch e.id {
	case "NOTREGULAR":
		return fmt.Sprintf("%s: %s: not a regular file", e.fun, e.file)
	default:
		return fmt.Sprintf("%s: %s: unknown path error %q", e.fun, e.file, e.id)
	}
}
// MyParseError reports a syntax error in a rule file.
//
// Finding the exact physical line of an error is hard once lines can be
// split with a trailing \, so the error reports the first and last physical
// lines of the offending logical line instead.
type MyParseError struct {
	fst  int    // first physical line (1-based)
	last int    // last physical line (1-based)
	info string // description of the problem
	file string // path of the rule file
}

// Error implements the error interface. The message has the form
// "Syntax error in file: FILE: line N: INFO" or, for a logical line that
// spans several physical lines, "...: lines N-M: INFO".
func (e *MyParseError) Error() string {
	where := fmt.Sprintf("line %d", e.fst)
	if e.fst != e.last {
		where = fmt.Sprintf("lines %d-%d", e.fst, e.last)
	}
	return fmt.Sprintf("Syntax error in file: %s: %s: %s", e.file, where, e.info)
}
// Rule is one parsed rule. The full udev rule specification is not
// supported: only ACTION, SUBSYSTEM, ATTR (partially) and RUN are. Nothing
// else is needed here, so other keys are deliberately left out.
type Rule struct {
	action    string            // required ACTION value; "" matches any action
	subsystem string            // required SUBSYSTEM value; "" matches any subsystem
	attr      map[string]string // sysfs attribute name -> required value
	run       []string          // shell commands started when the rule matches
}

// rules is the active rule set; compileConfigDir rebuilds it.
var rules []Rule
// main runs realMain and exits with the status it returns.
//
// os.Exit is deliberately not called from a deferred function. A deferred
// os.Exit(0) would run even while a panic is unwinding, which hides the panic
// and its stack trace and reports success. Calling it here lets realMain's
// deferred cleanup run normally and lets panics propagate.
func main() {
	os.Exit(realMain())
}

// realMain is the daemon proper. Its steps are:
//   - refuse to run as root;
//   - create the config directory ($XDG_CONFIG_HOME/udev-worker or
//     $HOME/.config/udev-worker);
//   - take the single-instance lock;
//   - compile the rules;
//   - dispatch udev events to handleDevice until a signal, an error or
//     removal of the config directory ends it.
//
// It returns the process exit status.
func realMain() int {
	log := FuncLogger{funcName: "main"}
	if os.Getuid() == 0 {
		log.LogErr("refusing to run as root")
		return 1
	}

	var configDir string
	if val, ok := os.LookupEnv("XDG_CONFIG_HOME"); ok {
		configDir = fmt.Sprintf("%s/udev-worker", val)
	} else if val, ok := os.LookupEnv("HOME"); ok {
		configDir = fmt.Sprintf("%s/.config/udev-worker", val)
	} else {
		log.LogErr("HOME environmental variable is not set")
		return 1
	}
	// 0777 on purpose: restricting permissions is left to the process umask.
	if err := os.MkdirAll(configDir, 0777); err != nil {
		log.LogErr(err)
		return 1
	}

	lock, logger, err := acquireLock(fmt.Sprintf("%s/lock", configDir))
	if errors.Is(err, syscall.EWOULDBLOCK) {
		log.Log("Already running")
		return 0
	} else if err != nil {
		logger.LogErr(err)
		return 1
	}
	defer releaseLock(lock)

	if logger, err := compileConfigDir(configDir); err != nil {
		logger.LogErr(err)
		return 1
	}

	watch := make(chan int, 10)
	watchErr := make(chan error, 10)
	watchLog := make(chan Logger, 10)
	go fsWatch(configDir, watch, watchErr, watchLog)

	udevCtx := udev.Udev{}
	mon := udevCtx.NewMonitorFromNetlink("udev")
	if mon == nil {
		log.LogErr("could not create udev netlink monitor")
		return 1
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Without this check a failure left devchan nil, so no event was ever
	// delivered and the daemon idled silently.
	devchan, errchan, err := mon.DeviceChan(ctx)
	if err != nil {
		log.LogErr(err)
		return 1
	}

	signalChan := make(chan os.Signal, 10)
	signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

	log.Log("entering mainloop")
	for {
		select {
		case dev, ok := <-devchan:
			if !ok {
				// A closed channel would otherwise hand nil devices to handleDevice.
				log.LogErr("udev device channel closed unexpectedly")
				return 1
			}
			if logger, err := handleDevice(dev, rules); err != nil {
				logger.LogErr(err)
			}
		case deverr := <-errchan:
			log.LogErr("a udev error, lmao", deverr)
			return 1
		case sig := <-signalChan:
			log.Log("exiting due to signal", sig)
			return 0
		case werr := <-watchErr:
			// fsWatch sends the Logger right after the error.
			logger := <-watchLog
			logger.LogErr(werr)
			return 1
		case i := <-watch:
			switch i {
			case 0:
				log.Log("Reconfiguring")
				if logger, err := compileConfigDir(configDir); err != nil {
					logger.LogErr(err)
					return 1
				}
			case 1:
				log.Log("Config dir has been removed, good bye idiot")
				return 0
			}
		}
	}
}
// readFile returns the lines of the regular file at path.
//
// Whitespace around the whole content is removed before splitting on '\n',
// and each resulting line is then trimmed as well. An empty file therefore
// yields a single empty string. A path that exists but is not a regular file
// produces a *MyPathError. On any error the returned slice is empty.
func readFile(path string) ([]string, Logger, error) {
	logger := FuncLogger{funcName: "readFile"}

	info, statErr := os.Stat(path)
	switch {
	case statErr != nil:
		return []string{}, logger, statErr
	case !info.Mode().IsRegular():
		return []string{}, logger, &MyPathError{id: "NOTREGULAR", file: path, fun: "os.Stat"}
	}

	content, readErr := ioutil.ReadFile(path)
	if readErr != nil {
		return []string{}, logger, readErr
	}

	lines := strings.Split(strings.TrimSpace(string(content)), "\n")
	for idx := range lines {
		lines[idx] = strings.TrimSpace(lines[idx])
	}
	return lines, logger, nil
}
// preParseConfigFile turns the physical lines of a rule file into logical
// lines. lines must already be trimmed of surrounding whitespace.
//
// The processing rules are:
//   - Blank lines and lines starting with '#' are dropped. This happens
//     before continuations are joined, so a comment between continued lines
//     is simply skipped.
//   - A kept line ending in a backslash has that backslash removed and is
//     joined with the next kept line. Joining repeats while the result still
//     ends in a backslash.
//   - A trailing backslash on the final logical line is stripped.
//
// Each ConfigLine records the 1-based numbers of the first and last physical
// lines it was built from. Input without any rule lines (empty, or only
// comments and blank lines) yields an empty slice. Previously that case
// panicked on cfg[len(cfg)-1].
func preParseConfigFile(lines []string) []ConfigLine {
	cfg := make([]ConfigLine, 0, len(lines))
	for i, line := range lines {
		if line == "" || line[0] == '#' {
			continue
		}
		if n := len(cfg); n > 0 && strings.HasSuffix(cfg[n-1].line, "\\") {
			// Continuation: replace the previous line's backslash with this line.
			prev := &cfg[n-1]
			prev.line = strings.TrimSuffix(prev.line, "\\") + line
			prev.last = i + 1
			continue
		}
		cfg = append(cfg, ConfigLine{fst: i + 1, last: i + 1, line: line})
	}
	if n := len(cfg); n > 0 {
		// Nothing follows the last line, so a dangling continuation is dropped.
		cfg[n-1].line = strings.TrimSuffix(cfg[n-1].line, "\\")
	}
	return cfg
}
// contains reports whether s equals any element of ss.
func contains(s string, ss []string) bool {
	for idx := 0; idx < len(ss); idx++ {
		if ss[idx] == s {
			return true
		}
	}
	return false
}
// parseRule parses one logical rule line into a Rule. Each line holds
// exactly one rule.
//
// A rule is a comma-separated list of KEY OP "VALUE" items, e.g.
//
//	ACTION=="add", SUBSYSTEM=="usb", ATTR{idVendor}=="046d", RUN+="notify-send hi"
//
// The supported keys and operators are:
//   - SUBSYSTEM, ACTION and ATTR{name} accept only the == operator.
//   - RUN accepts = (replace the command list) and += (append to it).
//   - SUBSYSTEM is limited to "usb" and "block".
//   - ACTION is limited to "add", "unbind", "remove", "change" and "bind".
//
// VALUE must be enclosed in single or double quotes. The only escape
// sequences inside it are \\ and a backslash followed by the enclosing quote.
// Spaces and tabs between items and around the operator are ignored.
//
// file is used only for error reporting. On failure the error is a
// *MyParseError, and the Rule returned with it is incomplete and must not be
// used.
func parseRule(cfgline ConfigLine, file string) (Rule, Logger, error) {
	log := FuncLogger{funcName: "parseRule"}
	e := &MyParseError{file: file, fst: cfgline.fst, last: cfgline.last}
	rule := Rule{attr: make(map[string]string), run: make([]string, 0)}
	const (
		subsystem = "SUBSYSTEM"
		action    = "ACTION"
		attribute = "ATTR"
		run       = "RUN"
		equal     = "=="
		assign    = "="
		add       = "+="
	)
	fail := func(info string) (Rule, Logger, error) {
		e.info = info
		return rule, log, e
	}
	// head returns at most the first 10 bytes of s for error messages.
	// Slicing s[:10] directly panicked on shorter input.
	head := func(s string) string {
		if len(s) > 10 {
			return s[:10]
		}
		return s
	}
	line := cfgline.line
	skipBlanks := func() { line = strings.TrimLeft(line, " \t") }

	for {
		// KEY (plus {name} for ATTR)
		skipBlanks()
		var keyword, attrKey string
		switch {
		case strings.HasPrefix(line, subsystem):
			keyword = subsystem
		case strings.HasPrefix(line, action):
			keyword = action
		case strings.HasPrefix(line, run):
			keyword = run
		case strings.HasPrefix(line, attribute):
			keyword = attribute
		default:
			return fail(fmt.Sprintf("unknown keyword, first 10 characters are \"%s\"", head(line)))
		}
		line = line[len(keyword):]
		if keyword == attribute {
			if !strings.HasPrefix(line, "{") {
				return fail("attribute key required")
			}
			end := strings.IndexByte(line, '}')
			if end < 0 {
				return fail("unterminated bracket in ATTR")
			}
			if end == 1 {
				// ATTR{} would name the device directory itself.
				return fail("attribute key required")
			}
			attrKey, line = line[1:end], line[end+1:]
		}

		// OP
		skipBlanks()
		if line == "" {
			return fail("premature end of line")
		}
		var operation string
		switch {
		case strings.HasPrefix(line, equal): // must be tested before "="
			operation = equal
		case strings.HasPrefix(line, assign):
			operation = assign
		case strings.HasPrefix(line, add):
			operation = add
		default:
			return fail(fmt.Sprintf("unknown operation, first 10 characters are \"%s\"", head(line)))
		}
		line = line[len(operation):]

		// "VALUE"
		skipBlanks()
		if line == "" {
			return fail("premature end of line")
		}
		quote := line[0]
		if quote != '\'' && quote != '"' {
			return fail("unquoted string")
		}
		var value strings.Builder
		closing := -1
	scan:
		for i := 1; i < len(line); i++ {
			switch c := line[i]; {
			case c == quote:
				closing = i
				break scan
			case c != '\\':
				value.WriteByte(c)
			case i == len(line)-1:
				return fail("premature end of line, unterminated quote")
			case line[i+1] == '\\' || line[i+1] == quote:
				value.WriteByte(line[i+1])
				i++ // the escaped character has been consumed
			default:
				return fail("unknown escape sequence")
			}
		}
		if closing < 0 {
			return fail("improperly quoted string")
		}
		inquotes := value.String()
		line = line[closing+1:]

		// separator: either end of line or a comma followed by another item
		skipBlanks()
		if line != "" {
			if line[0] != ',' {
				return fail("unrecognized token after quoted string, expected a comma")
			}
			line = line[1:]
			skipBlanks()
			if line == "" {
				return fail("premature end of line")
			}
		}

		switch {
		case keyword != run && operation != equal:
			return fail(fmt.Sprintf("%s does not support %s", keyword, operation))
		case keyword == run && operation == equal:
			return fail(fmt.Sprintf("%s does not support %s", keyword, operation))
		}

		switch keyword {
		case subsystem:
			if !contains(inquotes, []string{"usb", "block"}) {
				return fail(fmt.Sprintf("subsystem %s is not supported", inquotes))
			}
			rule.subsystem = inquotes
		case action:
			if !contains(inquotes, []string{"add", "unbind", "remove", "change", "bind"}) {
				return fail(fmt.Sprintf("action %s is not supported", inquotes))
			}
			rule.action = inquotes
		case attribute:
			rule.attr[attrKey] = inquotes
		case run:
			if operation == assign {
				rule.run = []string{inquotes}
			} else {
				rule.run = append(rule.run, inquotes)
			}
		}

		if line == "" {
			return rule, log, nil
		}
	}
}
// compileConfigDir parses every rule file in configDir and, on success,
// replaces the package-level rules with the result.
//
// Rule files are the regular files whose names match
// ^[a-zA-Z0-9_]+\.rules$; everything else, such as directories and the lock
// file, is ignored. Files are processed in lexical filename order
// (os.ReadDir sorts its result), so rule order is the same on every reload.
//
// On error the returned Logger is the one from the function that failed
// (readFile, parseRule or this function), and rules is left unchanged.
func compileConfigDir(configDir string) (Logger, error) {
	log := FuncLogger{funcName: "compileConfigDir"}
	entries, err := os.ReadDir(configDir)
	if err != nil {
		return log, err
	}
	// The dot is escaped: the previous pattern ".rules" also accepted names
	// such as "fooXrules".
	ruleFile := regexp.MustCompile(`^[a-zA-Z0-9_]+\.rules$`)

	compiled := make([]Rule, 0)
	for _, entry := range entries {
		// A single combined filter. The old two-step in-place removal could
		// index entries[-1] or past the end after deleting an element.
		if !entry.Type().IsRegular() || !ruleFile.MatchString(entry.Name()) {
			continue
		}
		rulefile := fmt.Sprintf("%s/%s", configDir, entry.Name())
		lines, logger, err := readFile(rulefile)
		if err != nil {
			return logger, err
		}
		for _, pre := range preParseConfigFile(lines) {
			rule, logger, err := parseRule(pre, rulefile)
			if err != nil {
				return logger, err
			}
			compiled = append(compiled, rule)
		}
	}
	rules = compiled
	return log, nil
}
// handleDevice starts the RUN commands of every rule in rules that matches
// dev.
//
// A rule matches when both of these hold:
//   - its subsystem and action are each empty or equal to the device's;
//   - every ATTR{key}=="value" constraint equals the whitespace-trimmed
//     contents of /sys/<devpath>/<key>.
//
// An attribute file that does not exist, is not a regular file, or fails to
// read for a reason other than permission makes the rule not match.
// Previously a missing file aborted processing of all remaining rules, a
// non-regular file counted as a match, and attribute files were leaked
// whenever they matched.
//
// Each command is started as /bin/sh -c "exec <command>" with this process's
// stdout and stderr. It is not waited for; a goroutine calls Wait to collect
// it. A permission error on an attribute, or a failure to start a command,
// stops processing and is returned together with this function's logger.
func handleDevice(dev *udev.Device, rules []Rule) (Logger, error) {
	log := FuncLogger{funcName: "handleDevice"}
	devpath := fmt.Sprintf("/sys/%s", dev.Devpath())
	for _, rule := range rules {
		if rule.subsystem != "" && rule.subsystem != dev.Subsystem() {
			continue
		}
		if rule.action != "" && rule.action != dev.Action() {
			continue
		}
		matched, err := deviceAttrsMatch(devpath, rule.attr, log)
		if err != nil {
			return log, err
		}
		if !matched {
			continue
		}
		for _, val := range rule.run {
			cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("exec %s", val))
			log.Log(fmt.Sprintf("executing '%s' in /bin/sh", val))
			cmd.Stdout = os.Stdout
			cmd.Stderr = os.Stderr
			if err := cmd.Start(); err != nil {
				return log, err
			}
			// Collect the child in the background so the event loop is not blocked.
			go cmd.Wait()
		}
	}
	return log, nil
}

// deviceAttrsMatch reports whether every attribute in attrs has the required
// value under the sysfs directory devpath. It stops at the first mismatch or
// error.
func deviceAttrsMatch(devpath string, attrs map[string]string, log Logger) (bool, error) {
	for key, want := range attrs {
		ok, err := sysAttrEquals(fmt.Sprintf("%s/%s", devpath, key), want, log)
		if err != nil || !ok {
			return false, err
		}
	}
	return true, nil
}

// sysAttrEquals reports whether the whitespace-trimmed contents of the
// attribute file at path equal want. The file is always closed. Only errors
// that should abort device handling are returned: open and stat failures
// other than "does not exist", and permission errors while reading.
func sysAttrEquals(path, want string, log Logger) (bool, error) {
	f, err := os.Open(path)
	if errors.Is(err, os.ErrNotExist) {
		// The device does not have this attribute, so it cannot match.
		return false, nil
	} else if err != nil {
		return false, err
	}
	defer f.Close()

	sb, err := f.Stat()
	if err != nil {
		return false, err
	}
	if !sb.Mode().IsRegular() {
		log.LogErr(fmt.Sprintf("%s: is not a regular file", path))
		return false, nil
	}
	val, err := ioutil.ReadAll(f)
	if errors.Is(err, os.ErrPermission) {
		return false, err
	} else if err != nil {
		// Unreadable attribute: report it and treat it as a mismatch.
		log.LogErr(err)
		return false, nil
	}
	return strings.TrimSpace(string(val)) == want, nil
}
// acquireLock opens lockfile, creating it with mode 0666 if needed, and
// takes an exclusive, non-blocking flock(2) lock on it.
//
// On success the open file is returned; it must stay open for as long as the
// lock should be held. On failure the file, if it was opened, is closed and
// nil is returned together with the unwrapped error. If the lock is already
// held, that error is what Flock reports (main checks for
// syscall.EWOULDBLOCK).
func acquireLock(lockfile string) (*os.File, Logger, error) {
	logger := FuncLogger{funcName: "acquireLock"}
	f, openErr := os.OpenFile(lockfile, os.O_CREATE|os.O_RDONLY, 0666)
	if openErr != nil {
		return nil, logger, openErr
	}
	if lockErr := syscall.Flock(int(f.Fd()), syscall.LOCK_NB|syscall.LOCK_EX); lockErr != nil {
		f.Close()
		return nil, logger, lockErr
	}
	return f, logger, nil
}
// releaseLock drops the flock(2) lock held through file and then closes it.
// Errors are ignored on purpose; this is best-effort cleanup.
func releaseLock(file *os.File) {
	fd := int(file.Fd())
	_ = syscall.Flock(fd, syscall.LOCK_UN)
	_ = file.Close()
}
// fsWatch watches the directory file and reports on ch. main passes in the
// config directory.
//
// What it sends on ch:
//   - 0 once the directory has been quiet for 10 seconds after a change. A
//     change is an entry that was written, created, removed or renamed; each
//     new change restarts the delay. It means the rules should be recompiled.
//   - 1 when the directory itself is removed or moved away. fsWatch then
//     returns.
//
// Previously only writes triggered a reload, so editors that save by
// renaming, and deleted rule files, went unnoticed.
//
// Fatal problems are reported by sending the error on errch followed by this
// function's Logger on logch, after which fsWatch returns. This covers
// watcher setup failures and a watcher channel closing unexpectedly; a
// closed channel used to be re-selected in a busy loop. Errors delivered on
// the watcher's error channel are only logged.
func fsWatch(file string, ch chan int, errch chan error, logch chan Logger) {
	log := FuncLogger{funcName: "fsWatch"}
	fail := func(err error) {
		errch <- err
		logch <- log
	}

	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		fail(err)
		return
	}
	defer watcher.Close()
	if err := watcher.Add(file); err != nil {
		fail(err)
		return
	}

	const delay = 10 * time.Second
	// The debounce timer is created on the first change. timerC is nil, so
	// receiving from it blocks forever, whenever no reload is pending.
	var timer *time.Timer
	var timerC <-chan time.Time
	defer func() {
		if timer != nil {
			timer.Stop()
		}
	}()
	arm := func() {
		if timer == nil {
			timer = time.NewTimer(delay)
		} else {
			// Discard an expiry that has not been received yet, so the full
			// delay restarts instead of firing immediately.
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timer.Reset(delay)
		}
		timerC = timer.C
	}

	const changed = fsnotify.Write | fsnotify.Create | fsnotify.Remove | fsnotify.Rename
	for {
		select {
		case <-timerC:
			timerC = nil
			ch <- 0
		case fsevent, ok := <-watcher.Events:
			if !ok {
				fail(errors.New("fsnotify event channel closed unexpectedly"))
				return
			}
			if fsevent.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
				// The event may name the directory itself, or the directory may
				// be gone under a differently spelled path; check both.
				if _, statErr := os.Stat(file); fsevent.Name == file || errors.Is(statErr, os.ErrNotExist) {
					ch <- 1
					return
				}
			}
			if fsevent.Op&changed != 0 {
				log.Log(fmt.Sprintf("Detected change to %s", fsevent.Name))
				arm()
			}
		case fserror, ok := <-watcher.Errors:
			if !ok {
				fail(errors.New("fsnotify error channel closed unexpectedly"))
				return
			}
			// Not fatal: log it and keep watching.
			log.LogErr(fserror)
		}
	}
}