diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index b8eb537..cf601c2 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -166,14 +166,65 @@ func (h *Handler) setHandlers() error { // Handling the shorted URLs, if no one exists, it checks // in the filesystem and sets headers for caching - h.engine.NoRoute(h.handleAccess, func(c *gin.Context) { - c.Header("Vary", "Accept-Encoding") - c.Header("Cache-Control", "public, max-age=2592000") - c.Header("ETag", util.VersionInfo.Commit) - }, gin.WrapH(http.FileServer(FS(false)))) + h.engine.NoRoute( + h.handleAccess, // look up shortcuts + func(c *gin.Context) { // no shortcut found, prep response for FS + c.Header("Vary", "Accept-Encoding") + c.Header("Cache-Control", "public, max-age=2592000") + c.Header("ETag", util.VersionInfo.Commit) + }, + // Pass down to the embedded FS, but let 404s escape via + // the interceptHandler. + gin.WrapH(interceptHandler(http.FileServer(FS(false)), customErrorHandler)), + // not in FS; redirect to root with customURL target filled out + func(c *gin.Context) { + // if we get to this point we should not let the client cache + c.Header("Cache-Control", "no-cache, no-store") + c.Redirect(http.StatusTemporaryRedirect, "/?customUrl="+c.Request.URL.Path[1:]) + }) return nil } +type interceptResponseWriter struct { + http.ResponseWriter + errH func(http.ResponseWriter, int) +} + +func (w *interceptResponseWriter) WriteHeader(status int) { + if status >= http.StatusBadRequest { + w.errH(w.ResponseWriter, status) + w.errH = nil + } else { + w.ResponseWriter.WriteHeader(status) + } +} + +type errorHandler func(http.ResponseWriter, int) + +func (w *interceptResponseWriter) Write(p []byte) (n int, err error) { + if w.errH == nil { + return len(p), nil + } + return w.ResponseWriter.Write(p) +} + +func interceptHandler(next http.Handler, errH errorHandler) http.Handler { + if errH == nil { + errH = customErrorHandler + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(&interceptResponseWriter{w, errH}, r) + }) +} + +func customErrorHandler(w http.ResponseWriter, status int) { + // let 404s fall through: the next NoRoute handler will redirect + // them back to the main page with the customURL box filled out. + if status != 404 { + http.Error(w, "error", status) + } +} + // Listen starts the http server func (h *Handler) Listen() error { return h.engine.Run(util.GetConfig().ListenAddr) diff --git a/internal/handlers/public_test.go b/internal/handlers/public_test.go index cfab783..2e616d9 100644 --- a/internal/handlers/public_test.go +++ b/internal/handlers/public_test.go @@ -275,8 +275,8 @@ func TestHandleDeletion(t *testing.T) { t.Fatalf("could not send visit request: %v", err) } fmt.Println(body.URL) - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("expected status: %d; got: %d", http.StatusNotFound, resp.StatusCode) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status: %d; got: %d", http.StatusOK, resp.StatusCode) } } diff --git a/web/src/Home/Home.js b/web/src/Home/Home.js index ac475b4..a5f1eac 100644 --- a/web/src/Home/Home.js +++ b/web/src/Home/Home.js @@ -10,21 +10,31 @@ import CustomCard from '../Card/Card' import './Home.css' export default class HomeComponent extends Component { + constructor(props) { + super(props); + this.urlParams = new URLSearchParams(window.location.search); + this.state = { + links: [], + usedSettings: this.urlParams.get('customUrl') ? ['custom'] : [], + customID: this.urlParams.get('customUrl') ? this.urlParams.get('customUrl') : '', + showCustomIDError: false, + expiration: null + } + } handleURLChange = (e, { value }) => this.url = value handlePasswordChange = (e, { value }) => this.password = value handleCustomExpirationChange = expire => this.setState({ expiration: expire }) handleCustomIDChange = (e, { value }) => { - this.customID = value + this.setState({customID: value}) util.lookupEntry(value, () => this.setState({ showCustomIDError: true }), () => this.setState({ showCustomIDError: false })) } - onSettingsChange = (e, { value }) => this.setState({ usedSettings: value }) - - state = { - links: [], - usedSettings: [], - showCustomIDError: false, - expiration: null + onSettingsChange = (e, { value }) => { + this.setState({ usedSettings: value }) } + + + + componentDidMount() { this.urlInput.focus() } @@ -32,7 +42,7 @@ export default class HomeComponent extends Component { if (!this.state.showCustomIDError) { util.createEntry({ URL: this.url, - ID: this.customID, + ID: this.state.customID, Expiration: this.state.usedSettings.includes("expire") && this.state.expiration ? this.state.expiration.toISOString() : undefined, Password: this.state.usedSettings.includes("protected") && this.password ? this.password : undefined }, r => this.setState({ @@ -56,13 +66,18 @@ export default class HomeComponent extends Component { return (