1
0
mirror of https://github.com/mgerb/go-discord-bot synced 2026-01-11 09:32:50 +00:00

convert to chi web framework - add tls support

This commit is contained in:
2017-07-20 21:34:36 -05:00
parent b915ec065a
commit 4baad6b025
6 changed files with 178 additions and 62 deletions

View File

@@ -13,6 +13,12 @@ NOTE: Currently the binaries in the release package only run on linux. Check the
- add your bot token and preferred upload password (leave as is for no password) - add your bot token and preferred upload password (leave as is for no password)
- run the bot with `./bot` (you may need to use sudo if you leave it on port 80) - run the bot with `./bot` (you may need to use sudo if you leave it on port 80)
## Flags
> -p, run in production mode
> -tls, run with auto tls
### NOTE ### NOTE
If you get a permissions error with ffmpeg on mac or linux: If you get a permissions error with ffmpeg on mac or linux:

View File

@@ -2,15 +2,19 @@ package config
import ( import (
"encoding/json" "encoding/json"
"flag"
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
) )
// Variables used for command line parameters // Variables used for command line parameters
var Config configStruct var (
Config configFile
Flags configFlags
)
type configStruct struct { type configFile struct {
Token string `json:"Token"` Token string `json:"Token"`
BotPrefix string `json:"BotPrefix"` //prefix to use for bot commands BotPrefix string `json:"BotPrefix"` //prefix to use for bot commands
SoundsPath string `json:"SoundsPath"` SoundsPath string `json:"SoundsPath"`
@@ -18,8 +22,21 @@ type configStruct struct {
ServerAddr string `json:"ServerAddr` ServerAddr string `json:"ServerAddr`
} }
type configFlags struct {
Prod bool
TLS bool
}
// Init -
func Init() { func Init() {
parseConfig()
parseFlags()
}
func parseConfig() {
log.Println("Reading config file...") log.Println("Reading config file...")
file, e := ioutil.ReadFile("./config.json") file, e := ioutil.ReadFile("./config.json")
@@ -36,5 +53,23 @@ func Init() {
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
}
func parseFlags() {
Flags.Prod = false
Flags.TLS = false
prod := flag.Bool("p", false, "Run in production")
tls := flag.Bool("tls", false, "Use TLS")
flag.Parse()
Flags.Prod = *prod
Flags.TLS = *tls
if *prod {
log.Println("Running in production mode")
}
} }

View File

@@ -0,0 +1,32 @@
package response
import (
"encoding/json"
"net/http"
)
var (
DefaultUnauthorized = []byte("Unauthorized.")
DefaultInternalError = []byte("Internal error.")
)
// JSON - marshals the provided interface and returns JSON to client
func JSON(w http.ResponseWriter, content interface{}) {
output, err := json.Marshal(content)
if err != nil {
ERR(w, http.StatusInternalServerError, []byte("Internal error."))
return
}
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(output)
}
// ERR - send error response
func ERR(w http.ResponseWriter, status int, content []byte) {
w.WriteHeader(status)
w.Write(content)
}

View File

@@ -1,13 +1,14 @@
package handlers package handlers
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"strings" "strings"
"net/http"
"github.com/mgerb/chi_auth_server/response"
"github.com/mgerb/go-discord-bot/server/config" "github.com/mgerb/go-discord-bot/server/config"
"github.com/valyala/fasthttp"
) )
var soundList []sound var soundList []sound
@@ -18,27 +19,21 @@ type sound struct {
Extension string `json:"extension"` Extension string `json:"extension"`
} }
func SoundList(ctx *fasthttp.RequestCtx) { // SoundList -
func SoundList(w http.ResponseWriter, r *http.Request) {
if len(soundList) < 1 { if len(soundList) < 1 {
err := PopulateSoundList() err := PopulateSoundList()
if err != nil { if err != nil {
ctx.Error(err.Error(), 400) response.ERR(w, http.StatusInternalServerError, []byte(response.DefaultInternalError))
return return
} }
} }
response, err := json.Marshal(soundList) response.JSON(w, soundList)
if err != nil {
ctx.Error("Error marshaling json", 400)
return
}
ctx.SetContentType("application/json")
ctx.Write(response)
} }
// PopulateSoundList -
func PopulateSoundList() error { func PopulateSoundList() error {
fmt.Println("Populating sound list.") fmt.Println("Populating sound list.")

View File

@@ -4,27 +4,32 @@ import (
"io" "io"
"os" "os"
"net/http"
"github.com/mgerb/chi_auth_server/response"
"github.com/mgerb/go-discord-bot/server/config" "github.com/mgerb/go-discord-bot/server/config"
"github.com/valyala/fasthttp"
) )
func FileUpload(ctx *fasthttp.RequestCtx) { func FileUpload(w http.ResponseWriter, r *http.Request) {
password := ctx.FormValue("password")
password := r.FormValue("password")
if string(password) != config.Config.UploadPassword { if string(password) != config.Config.UploadPassword {
ctx.Error("Invalid password.", 400) response.ERR(w, http.StatusInternalServerError, []byte("Invalid password."))
return return
} }
file, err := ctx.FormFile("file") file, header, err := r.FormFile("file")
if err != nil { if err != nil {
ctx.Error("Error reading file.", 400) response.ERR(w, http.StatusInternalServerError, []byte("Error reading file."))
return return
} }
src, err := file.Open() defer file.Close()
src, err := header.Open()
if err != nil { if err != nil {
ctx.Error("Error opening file.", 400) response.ERR(w, http.StatusInternalServerError, []byte("Error opening file."))
return return
} }
@@ -36,21 +41,21 @@ func FileUpload(ctx *fasthttp.RequestCtx) {
} }
// check if file already exists // check if file already exists
if _, err := os.Stat(config.Config.SoundsPath + file.Filename); err == nil { if _, err := os.Stat(config.Config.SoundsPath + header.Filename); err == nil {
ctx.Error("File already exists.", 400) response.ERR(w, http.StatusInternalServerError, []byte("File already exists."))
return return
} }
dst, err := os.Create(config.Config.SoundsPath + file.Filename) dst, err := os.Create(config.Config.SoundsPath + header.Filename)
if err != nil { if err != nil {
ctx.Error("Error creating file.", 400) response.ERR(w, http.StatusInternalServerError, []byte("Error creating file."))
return return
} }
defer dst.Close() defer dst.Close()
if _, err = io.Copy(dst, src); err != nil { if _, err = io.Copy(dst, src); err != nil {
ctx.Error("Error writing file.", 400) response.ERR(w, http.StatusInternalServerError, []byte("Error writing file."))
return return
} }
@@ -58,9 +63,9 @@ func FileUpload(ctx *fasthttp.RequestCtx) {
err = PopulateSoundList() err = PopulateSoundList()
if err != nil { if err != nil {
ctx.Error("File uploaded, but error populating sound list.", 400) response.ERR(w, http.StatusInternalServerError, []byte("Error populating sound list."))
return return
} }
ctx.Success("application/json", []byte("Success!")) response.JSON(w, []byte("Success"))
} }

View File

@@ -2,48 +2,91 @@ package webserver
import ( import (
"log" "log"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/buaazp/fasthttprouter" "golang.org/x/crypto/acme/autocert"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/mgerb/go-discord-bot/server/config" "github.com/mgerb/go-discord-bot/server/config"
"github.com/mgerb/go-discord-bot/server/webserver/handlers" "github.com/mgerb/go-discord-bot/server/webserver/handlers"
"github.com/valyala/fasthttp"
) )
func logger(next fasthttp.RequestHandler) fasthttp.RequestHandler { func getRouter() *chi.Mux {
return func(ctx *fasthttp.RequestCtx) {
logger := ctx.Logger() r := chi.NewRouter()
logger.Printf(ctx.RemoteAddr().String())
next(ctx) r.Use(middleware.RequestID)
} r.Use(middleware.DefaultCompress)
}
if !config.Flags.Prod {
func applyMiddleware(handler fasthttp.RequestHandler) fasthttp.RequestHandler { r.Use(middleware.Logger)
newHandler := logger(handler)
return newHandler
}
func registerRoutes(router *fasthttprouter.Router) {
router.GET("/soundlist", handlers.SoundList)
router.PUT("/upload", handlers.FileUpload)
router.ServeFiles("/static/*filepath", "./dist/static")
router.ServeFiles("/sounds/*filepath", config.Config.SoundsPath)
router.NotFound = func(ctx *fasthttp.RequestCtx) {
fasthttp.ServeFile(ctx, "./dist/index.html")
} }
r.Get("/soundlist", handlers.SoundList)
r.Put("/upload", handlers.FileUpload)
workDir, _ := os.Getwd()
FileServer(r, "/static", http.Dir(filepath.Join(workDir, "./dist/static")))
FileServer(r, "/sounds", http.Dir(filepath.Join(workDir, "./sounds")))
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "./dist/index.html")
})
return r
} }
// Start -
func Start() { func Start() {
router := fasthttprouter.New() router := getRouter()
registerRoutes(router) if config.Flags.TLS {
// apply our middleware // start server on port 80 to redirect
handlers := applyMiddleware(router.Handler) go http.ListenAndServe(":80", http.HandlerFunc(redirect))
// start web server // start TLS server
log.Fatal(fasthttp.ListenAndServe(config.Config.ServerAddr, handlers)) log.Fatal(http.Serve(autocert.NewListener(), router))
} else {
// start basic server
http.ListenAndServe(config.Config.ServerAddr, router)
}
}
// FileServer conveniently sets up a http.FileServer handler to serve
// static files from a http.FileSystem.
func FileServer(r chi.Router, path string, root http.FileSystem) {
if strings.ContainsAny(path, "{}*") {
panic("FileServer does not permit URL parameters.")
}
fs := http.StripPrefix(path, http.FileServer(root))
if path != "/" && path[len(path)-1] != '/' {
r.Get(path, http.RedirectHandler(path+"/", 301).ServeHTTP)
path += "/"
}
path += "*"
r.Get(path, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fs.ServeHTTP(w, r)
}))
}
// redirect to https
func redirect(w http.ResponseWriter, req *http.Request) {
// remove/add not default ports from req.Host
target := "https://" + req.Host + req.URL.Path
if len(req.URL.RawQuery) > 0 {
target += "?" + req.URL.RawQuery
}
http.Redirect(w, req, target, http.StatusTemporaryRedirect)
} }